org.deeplearning4j.util.StringGrid.java Source code

Java tutorial

Introduction

Here is the source code for org.deeplearning4j.util.StringGrid.java

Source

/*-
 *
 *  * Copyright 2015 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 */

package org.deeplearning4j.util;

import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.berkeley.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.util.*;

import static org.deeplearning4j.berkeley.StringUtils.splitOnCharWithQuoting;

/**
 * String matrix
 * @author Adam Gibson
 *
 */
public class StringGrid extends ArrayList<List<String>> {

    private static final long serialVersionUID = 4702427632483221813L;
    private String sep;
    private int numColumns = -1;
    private static final Logger log = LoggerFactory.getLogger(StringGrid.class);
    public final static String NONE = "NONE";

    public StringGrid(StringGrid grid) {
        this.sep = grid.sep;
        this.numColumns = grid.numColumns;
        addAll(grid);
        fillOut();
    }

    public StringGrid(String sep, int numColumns) {
        this(sep, new ArrayList<String>());
        this.numColumns = numColumns;
        fillOut();
    }

    public int getNumColumns() {
        return numColumns;
    }

    private void fillOut() {
        for (List<String> list : this) {
            if (list.size() < numColumns) {
                int diff = numColumns - list.size();
                for (int i = 0; i < diff; i++) {
                    list.add(NONE);
                }
            }
        }
    }

    public static StringGrid fromFile(String file, String sep) throws IOException {
        List<String> read = FileUtils.readLines(new File(file));
        if (read.isEmpty())
            throw new IllegalStateException("Nothing to read; file is empty");

        return new StringGrid(sep, read);

    }

    public static StringGrid fromInput(InputStream from, String sep) throws IOException {
        List<String> read = IOUtils.readLines(from);
        if (read.isEmpty())
            throw new IllegalStateException("Nothing to read; file is empty");

        return new StringGrid(sep, read);

    }

    public StringGrid(String sep, Collection<String> data) {
        super();
        this.sep = sep;
        List<String> list = new ArrayList<>(data);
        for (int i = 0; i < list.size(); i++) {
            String line = list.get(i).trim();
            //text delimiter
            if (line.indexOf('\"') > 0) {
                Counter<Character> counter = new Counter<>();
                for (int j = 0; j < line.length(); j++) {
                    counter.incrementCount(line.charAt(j), 1.0);
                }
                if (counter.getCount('"') > 1) {
                    String[] split = splitOnCharWithQuoting(line, sep.charAt(0), '"', '\\');
                    add(new ArrayList<>(Arrays.asList(split)));
                } else {
                    List<String> row = new ArrayList<>(
                            Arrays.asList(splitOnCharWithQuoting(line, sep.charAt(0), '"', '\\')));
                    if (numColumns < 0)
                        numColumns = row.size();
                    else if (row.size() != numColumns)
                        log.warn("Row " + i + " had invalid number of columns  line was " + line);
                    add(row);
                }

            } else {
                List<String> row = new ArrayList<>(
                        Arrays.asList(splitOnCharWithQuoting(line, sep.charAt(0), '"', '\\')));
                if (numColumns < 0)
                    numColumns = row.size();
                else if (row.size() != numColumns) {
                    log.warn("Could not add " + line);
                }
                add(row);
            }

        }
        fillOut();
    }

    /**
     * Removes all rows with a column of NONE
     * @param column the column to remove by
     */
    public void removeRowsWithEmptyColumn(int column) {
        List<List<String>> remove = new ArrayList<>();
        for (List<String> list : this) {
            if (list.get(column).equals(NONE))
                remove.add(list);
        }
        removeAll(remove);
    }

    public void head(int num) {
        if (num >= size())
            num = size();
        StringBuilder builder = new StringBuilder();
        for (int i = 0; i < num; i++) {
            builder.append(get(i) + "\n");
        }

        log.info(builder.toString());
    }

    /**
     * Removes the specified columns from the grid
     * @param columns the columns to remove
     */
    public void removeColumns(Integer... columns) {
        if (columns.length < 1)
            throw new IllegalArgumentException("Columns must contain at least one column");
        List<Integer> removeOrder = Arrays.asList(columns);
        //put them in the right order for removing
        Collections.sort(removeOrder);
        for (List<String> list : this) {
            List<String> remove = new ArrayList<>();
            for (int i = 0; i < columns.length; i++) {
                remove.add(list.get(columns[i]));
            }
            list.removeAll(remove);
        }

    }

    /**
     * Removes all rows with a column of missingValue
     * @param column he column to remove by
     * @param missingValue the missingValue sentinel value
     */
    public void removeRowsWithEmptyColumn(int column, String missingValue) {
        List<List<String>> remove = new ArrayList<>();
        for (List<String> list : this) {
            if (list.get(column).equals(missingValue))
                remove.add(list);
        }
        removeAll(remove);
    }

    public List<List<String>> getRowsWithColumnValues(Collection<String> values, int column) {
        List<List<String>> ret = new ArrayList<>();
        for (List<String> val : this) {
            if (values.contains(val.get(column)))
                ret.add(val);
        }
        return ret;
    }

    public void sortColumnsByWordLikelihoodIncluded(final int column) {
        final Counter<String> counter = new Counter<>();
        List<String> col = getColumn(column);

        for (String s : col) {
            StringTokenizer tokenizer = new StringTokenizer(s);
            while (tokenizer.hasMoreTokens()) {
                counter.incrementCount(tokenizer.nextToken(), 1.0);
            }
        }

        if (counter.totalCount() <= 0.0) {
            log.warn("Unable to calculate probability; nothing found");
            return;
        }

        //laplace smoothing
        counter.incrementAll(counter.keySet(), 1.0);
        Set<String> remove = new HashSet<>();
        for (String key : counter.keySet())
            if (key.length() < 2 || key.matches("[a-z]+"))
                remove.add(key);
        for (String key : remove)
            counter.removeKey(key);

        counter.pruneKeysBelowThreshold(4.0);

        final double totalCount = counter.totalCount();

        Collections.sort(this, new Comparator<List<String>>() {

            @Override
            public int compare(List<String> o1, List<String> o2) {
                double c1 = sumOverTokens(counter, o1.get(column), totalCount);
                double c2 = sumOverTokens(counter, o2.get(column), totalCount);
                return Double.compare(c1, c2);
            }

        });
    }

    /* Return the log sum of the column relative to the word frequencies (equivalent to the probability in log space */
    private double sumOverTokens(Counter<String> counter, String column, double totalCount) {
        StringTokenizer tokenizer = new StringTokenizer(column);
        double count = 0;
        while (tokenizer.hasMoreTokens())
            count += Math.log(counter.getCount(column) / totalCount);

        return count;
    }

    public StringCluster clusterColumn(int column) {
        return new StringCluster(getColumn(column));

    }

    public void dedupeByClusterAll() {
        for (int i = 0; i < size(); i++)
            dedupeByCluster(i);
    }

    /**
     * Deduplicate based on the column clustering signature
     * @param column
     */
    public void dedupeByCluster(int column) {
        StringCluster cluster = clusterColumn(column);
        System.out.println(cluster.get("family mcdonalds restaurant"));
        System.out.println(cluster.get("family mcdonalds restaurants"));
        List<Map<String, Integer>> list2 = cluster.getClusters();
        for (int i = 0; i < list2.size(); i++) {
            if (list2.get(i).size() > 1) {
                System.out.println(list2.get(i));
            }
        }
        FingerPrintKeyer keyer = new FingerPrintKeyer();
        Set<Integer> alreadyDeDupped = new HashSet<>();
        for (int i = 0; i < size(); i++) {
            String key = keyer.key(get(i).get(column));
            Map<String, Integer> map = cluster.get(key);
            if (map != null && map.size() > 1) {
                List<Integer> list = filterRowsByColumn(column, map.keySet());
                //deduplication to do
                if (list.size() > 1)
                    modifyRows(alreadyDeDupped, column, list, map);
            }

        }
    }

    /**
     * Cleans up the rows specified that haven't already been deduplified
     * @param alreadyDeDupped the already dedupped rows
     * @param column the column to homogenize
     * @param rows the rows to preProcess
     * @param cluster the cluster of values
     */
    private void modifyRows(Set<Integer> alreadyDeDupped, Integer column, List<Integer> rows,
            Map<String, Integer> cluster) {
        String chosenKey = null;
        Integer max = null;
        for (Map.Entry<String, Integer> entry : cluster.entrySet()) {
            String key = entry.getKey();
            int value = entry.getValue();
            StringTokenizer val = new StringTokenizer(key);
            List<String> list = new ArrayList<>();
            boolean allLower = true;

            outer: while (val.hasMoreTokens()) {
                String token = val.nextToken();
                //weird capitalization
                if (token.length() >= 3 && token.matches("[A-Z]+")) {
                    continue outer;
                }
                list.add(token);
            }

            for (String s : list) {
                allLower = allLower && s.matches("[a-z]+");

            }
            if (allLower) {
                continue;
            }

            //not a proper name
            if (list.get(list.size() - 1).toLowerCase().equals("the")) {
                continue;
            }
            //first selection that's valid or count is higher
            if (max == null || (!allLower && value > max)) {
                max = value;
                chosenKey = key;
            }
        }

        //wtf is wrong with you people?
        if (chosenKey == null) {
            //getFromOrigin the max value of the cluster
            String max2 = maximalValue(cluster);
            StringTokenizer val = new StringTokenizer(max2);
            List<String> list = new ArrayList<>();
            while (val.hasMoreTokens()) {
                String token = val.nextToken();
                //weird capitalization
                if (token.length() >= 3 && token.matches("[A-Z]+")) {
                    token = token.charAt(0) + token.substring(1).toLowerCase();
                }
                list.add(token);
            }

            boolean allLower = true;
            for (String s : list)
                allLower = allLower && s.matches("[a-z]+");
            if (list.get(list.size() - 1).toLowerCase().equals("the")) {
                max2 = max2.replaceAll("^[Tt]he", "");

            }
            if (allLower)
                max2 = StringUtils.capitalize(max2);
            chosenKey = max2;

        }

        for (Integer i2 : rows) {
            //row already processed
            if (!alreadyDeDupped.contains(i2)) {
                disambiguateRow(i2, column, chosenKey);

            }
        }
    }

    private String maximalValue(Map<String, Integer> map) {
        Counter<String> counter = new Counter<>();
        for (Map.Entry<String, Integer> entry : map.entrySet()) {
            counter.incrementCount(entry.getKey(), entry.getValue());
        }
        return counter.argMax();
    }

    private void disambiguateRow(Integer row, Integer column, String chosenValue) {
        System.out.println("SETTING " + row + " column " + column + " to " + chosenValue);
        get(row).set(column, chosenValue);
    }

    public List<Integer> filterRowsByColumn(int column, Collection<String> values) {
        List<Integer> list = new ArrayList<>();
        for (int i = 0; i < size(); i++) {
            if (values.contains(get(i).get(column)))
                list.add(i);
        }
        return list;
    }

    public void sortBy(final int column) {
        Collections.sort(this, new Comparator<List<String>>() {

            @Override
            public int compare(List<String> o1, List<String> o2) {
                return o1.get(column).compareTo(o2.get(column));
            }

        });
    }

    public List<String> toLines() {
        List<String> lines = new ArrayList<>();
        for (List<String> list : this) {
            StringBuilder sb = new StringBuilder();
            for (String s : list) {
                sb.append(s.replaceAll(sep, " "));
                sb.append(sep);
            }

            lines.add(sb.toString().substring(0, sb.lastIndexOf(sep)));
        }
        return lines;
    }

    public void swap(int column1, int column2) {
        List<String> col1 = getColumn(column1);
        List<String> col2 = getColumn(column2);
        for (int i = 0; i < size(); i++) {
            get(i).set(column1, col2.get(i));
            get(i).set(column2, col1.get(i));
        }
    }

    public void merge(int column1, int column2) {
        checkInvalidColumn(column1);
        checkInvalidColumn(column2);

        if (column1 != column2)
            for (List<String> list : this) {
                StringBuilder sb = new StringBuilder();
                sb.append(list.get(column1));
                sb.append(list.get(column2));
                list.set(Math.min(column1, column2), sb.toString().replaceAll("\"", "").replace(sep, " "));
                list.remove(Math.max(column1, column2));
            }
        numColumns--;
    }

    public StringGrid getAllWithSimilarity(double threshold, int firstColumn, int secondColumn) {
        for (int column : new int[] { firstColumn, secondColumn })
            checkInvalidColumn(column);
        StringGrid grid = new StringGrid(sep, numColumns);
        for (List<String> list : this) {
            double sim = MathUtils.stringSimilarity(list.get(firstColumn), list.get(secondColumn));
            if (sim >= threshold)
                grid.addRow(list);
        }
        return grid;

    }

    public void writeLinesTo(String path) throws IOException {
        FileUtils.writeLines(new File(path), toLines());
    }

    public void fillDown(String value, int column) {
        checkInvalidColumn(column);
        for (List<String> list : this)
            list.set(column, value);
    }

    public StringGrid select(int column, String value) {
        StringGrid grid = new StringGrid(sep, numColumns);
        for (int i = 0; i < size(); i++) {
            List<String> row = get(i);
            if (row.get(column).equals(value)) {
                grid.addRow(row);
            }

        }
        return grid;
    }

    public void split(int column, String sepBy) {
        List<String> col = getColumn(column);
        int validate = -1;
        Set<String> remove = new HashSet<>();
        for (int i = 0; i < col.size(); i++) {
            String s = col.get(i);
            String[] split2 = StringUtils.splitOnCharWithQuoting(s, sepBy.charAt(0), '"', '\\');
            if (validate < 0)
                validate = split2.length;
            else if (validate != split2.length) {
                log.warn("Row " + get(i) + " will be invalid after split; removing");
                remove.add(s);
            }
        }

        for (String s : remove) {
            StringGrid grid = select(column, s);
            removeAll(grid);
        }
        Map<Integer, List<String>> replace = new HashMap<>();
        for (int i = 0; i < size(); i++) {
            List<String> list = get(i);
            List<String> newList = new ArrayList<>();
            String split = list.get(column);
            String[] split2 = StringUtils.splitOnCharWithQuoting(split, sepBy.charAt(0), '"', '\\');
            //add right next to where column was split
            for (int j = 0; j < list.size(); j++) {
                if (j == column)
                    for (String s : split2)
                        newList.add(s);

                else
                    newList.add(list.get(j));
            }
            replace.put(i, newList);

        }

        //prevent concurrent modification
        for (Map.Entry<Integer, List<String>> entry : replace.entrySet()) {
            set(entry.getKey(), entry.getValue());
        }
    }

    public void filterBySimilarity(double threshold, int firstColumn, int secondColumn) {
        for (int column : new int[] { firstColumn, secondColumn })
            checkInvalidColumn(column);
        List<List<String>> remove = new ArrayList<>();
        for (List<String> list : this) {
            double sim = MathUtils.stringSimilarity(list.get(firstColumn), list.get(secondColumn));
            if (sim < threshold)
                remove.add(list);
        }
        removeAll(remove);
    }

    public void prependToEach(String prepend, int toColumn) {
        for (List<String> row : this) {
            String currVal = row.get(toColumn);
            row.set(toColumn, prepend + currVal);
        }
    }

    public void appendToEach(String append, int toColumn) {
        for (List<String> row : this) {
            String currVal = row.get(toColumn);
            row.set(toColumn, currVal + append);
        }
    }

    public void addColumn(List<String> column) {
        if (column.size() != this.size())
            throw new IllegalArgumentException("Unable to add column; not enough rows");
        for (int i = 0; i < size(); i++) {
            get(i).add(column.get(i));
        }
    }

    /**
     * Combine the column based on a template and a number of template variable
     * columns. Note that this will also collapse the columns specified (removing them)
     *
     * @param templateColumn the column with the template ( uses printf style templating)
     * @param paramColumns the columns with template variables
     */
    public void combineColumns(int templateColumn, Integer[] paramColumns) {
        for (List<String> list : this) {
            List<String> format = new ArrayList<>();
            for (int j : paramColumns)
                format.add(list.get(j));

            list.set(templateColumn,
                    String.format(list.get(templateColumn), (Object[]) format.toArray(new String[] {})));
            //collapse columns
            list.removeAll(format);
        }
    }

    /**
     * Combine the column based on a template and a number of template variable
     * columns. Note that this will also collapse the columns specified (removing them)
     *
     * @param templateColumn the column with the template ( uses printf style templating)
     * @param paramColumns the columns with template variables
     */
    public void combineColumns(int templateColumn, int[] paramColumns) {
        for (List<String> list : this) {
            List<String> format = new ArrayList<>();
            for (int j : paramColumns)
                format.add(list.get(j));

            list.set(templateColumn,
                    String.format(list.get(templateColumn), (Object[]) format.toArray(new String[] {})));
            //collapse columns
            list.removeAll(format);
        }
    }

    public void addRow(List<String> row) {
        if (row.isEmpty()) {
            log.warn("Unable to add empty row");
        }

        else if (!isEmpty() && row.size() != get(0).size()) {
            log.warn("Unable to add row; not the same number of columns");
        } else
            add(row);
    }

    public Map<String, List<List<String>>> mapByPrimaryKey(int columnKey) {
        Map<String, List<List<String>>> map = new HashMap<>();
        for (List<String> line : this) {
            String val = line.get(columnKey);
            List<List<String>> get = map.get(val);
            if (get == null) {
                get = new ArrayList<>();
                map.put(val, get);
            }
            get.add(new ArrayList<>(Arrays.asList(sep)));

        }
        return map;
    }

    public List<String> getRow(int row) {
        checkInvalidRow(row);
        return new ArrayList<>(get(row));
    }

    public List<String> getColumn(int column) {
        checkInvalidColumn(column);
        List<String> ret = new ArrayList<>();
        for (List<String> list : this) {
            ret.add(list.get(column));
        }
        return ret;
    }

    private void checkInvalidRow(int row) {
        if (row < 0 || row >= size())
            throw new IllegalArgumentException("Row does not exist");
    }

    private void checkInvalidColumn(int column) {
        if (column < 0 || column >= numColumns)
            throw new IllegalArgumentException("Invalid column " + column);
    }

    public StringGrid getRowsWithDuplicateValuesInColumn(int column) {
        checkInvalidColumn(column);
        StringGrid grid = new StringGrid(sep, numColumns);
        List<String> columns = getColumn(column);
        Counter<String> counter = new Counter<>();
        for (String val : columns)
            counter.incrementCount(val, 1.0);
        counter.pruneKeysBelowThreshold(2.0);
        Set<String> keys = counter.keySet();
        for (List<String> row : this) {
            for (String key : keys)
                if (row.get(column).equals(key))
                    grid.addRow(row);

        }
        return grid;
    }

    public StringGrid getRowWithOnlyOneOccurrence(int column) {
        checkInvalidColumn(column);
        StringGrid grid = new StringGrid(sep, numColumns);
        List<String> columns = getColumn(column);
        Counter<String> counter = new Counter<>();
        for (String val : columns)
            counter.incrementCount(val, 1.0);

        Set<String> keys = new HashSet<>(counter.keySet());
        for (String key : keys) {
            if (counter.getCount(key) > 1) {
                counter.removeKey(key);
            }
        }

        for (List<String> row : this) {
            for (String key : keys)
                if (row.get(column).equals(key))
                    grid.addRow(row);

        }
        return grid;
    }

    public StringGrid getUniqueRows() {
        StringGrid ret = new StringGrid(this);
        ret.stripDuplicateRows();
        return ret;
    }

    public void stripDuplicateRows() {
        Set<List<String>> set = new HashSet<>(this);
        clear();
        addAll(set);
    }

}