Java tutorial
/*- * * * Copyright 2015 Skymind,Inc. * * * * Licensed under the Apache License, Version 2.0 (the "License"); * * you may not use this file except in compliance with the License. * * You may obtain a copy of the License at * * * * http://www.apache.org/licenses/LICENSE-2.0 * * * * Unless required by applicable law or agreed to in writing, software * * distributed under the License is distributed on an "AS IS" BASIS, * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * * See the License for the specific language governing permissions and * * limitations under the License. * */ package org.deeplearning4j.util; import org.apache.commons.io.FileUtils; import org.apache.commons.io.IOUtils; import org.deeplearning4j.berkeley.Counter; import org.deeplearning4j.berkeley.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.File; import java.io.IOException; import java.io.InputStream; import java.util.*; import static org.deeplearning4j.berkeley.StringUtils.splitOnCharWithQuoting; /** * String matrix * @author Adam Gibson * */ public class StringGrid extends ArrayList<List<String>> { private static final long serialVersionUID = 4702427632483221813L; private String sep; private int numColumns = -1; private static final Logger log = LoggerFactory.getLogger(StringGrid.class); public final static String NONE = "NONE"; public StringGrid(StringGrid grid) { this.sep = grid.sep; this.numColumns = grid.numColumns; addAll(grid); fillOut(); } public StringGrid(String sep, int numColumns) { this(sep, new ArrayList<String>()); this.numColumns = numColumns; fillOut(); } public int getNumColumns() { return numColumns; } private void fillOut() { for (List<String> list : this) { if (list.size() < numColumns) { int diff = numColumns - list.size(); for (int i = 0; i < diff; i++) { list.add(NONE); } } } } public static StringGrid fromFile(String file, String sep) throws IOException { List<String> read = FileUtils.readLines(new File(file)); if (read.isEmpty()) throw new IllegalStateException("Nothing to read; file is empty"); return new StringGrid(sep, read); } public static StringGrid fromInput(InputStream from, String sep) throws IOException { List<String> read = IOUtils.readLines(from); if (read.isEmpty()) throw new IllegalStateException("Nothing to read; file is empty"); return new StringGrid(sep, read); } public StringGrid(String sep, Collection<String> data) { super(); this.sep = sep; List<String> list = new ArrayList<>(data); for (int i = 0; i < list.size(); i++) { String line = list.get(i).trim(); //text delimiter if (line.indexOf('\"') > 0) { Counter<Character> counter = new Counter<>(); for (int j = 0; j < line.length(); j++) { counter.incrementCount(line.charAt(j), 1.0); } if (counter.getCount('"') > 1) { String[] split = splitOnCharWithQuoting(line, sep.charAt(0), '"', '\\'); add(new ArrayList<>(Arrays.asList(split))); } else { List<String> row = new ArrayList<>( Arrays.asList(splitOnCharWithQuoting(line, sep.charAt(0), '"', '\\'))); if (numColumns < 0) numColumns = row.size(); else if (row.size() != numColumns) log.warn("Row " + i + " had invalid number of columns line was " + line); add(row); } } else { List<String> row = new ArrayList<>( Arrays.asList(splitOnCharWithQuoting(line, sep.charAt(0), '"', '\\'))); if (numColumns < 0) numColumns = row.size(); else if (row.size() != numColumns) { log.warn("Could not add " + line); } add(row); } } fillOut(); } /** * Removes all rows with a column of NONE * @param column the column to remove by */ public void removeRowsWithEmptyColumn(int column) { List<List<String>> remove = new ArrayList<>(); for (List<String> list : this) { if (list.get(column).equals(NONE)) remove.add(list); } removeAll(remove); } public void head(int num) { if (num >= size()) num = size(); StringBuilder builder = new StringBuilder(); for (int i = 0; i < num; i++) { builder.append(get(i) + "\n"); } log.info(builder.toString()); } /** * Removes the specified columns from the grid * @param columns the columns to remove */ public void removeColumns(Integer... columns) { if (columns.length < 1) throw new IllegalArgumentException("Columns must contain at least one column"); List<Integer> removeOrder = Arrays.asList(columns); //put them in the right order for removing Collections.sort(removeOrder); for (List<String> list : this) { List<String> remove = new ArrayList<>(); for (int i = 0; i < columns.length; i++) { remove.add(list.get(columns[i])); } list.removeAll(remove); } } /** * Removes all rows with a column of missingValue * @param column he column to remove by * @param missingValue the missingValue sentinel value */ public void removeRowsWithEmptyColumn(int column, String missingValue) { List<List<String>> remove = new ArrayList<>(); for (List<String> list : this) { if (list.get(column).equals(missingValue)) remove.add(list); } removeAll(remove); } public List<List<String>> getRowsWithColumnValues(Collection<String> values, int column) { List<List<String>> ret = new ArrayList<>(); for (List<String> val : this) { if (values.contains(val.get(column))) ret.add(val); } return ret; } public void sortColumnsByWordLikelihoodIncluded(final int column) { final Counter<String> counter = new Counter<>(); List<String> col = getColumn(column); for (String s : col) { StringTokenizer tokenizer = new StringTokenizer(s); while (tokenizer.hasMoreTokens()) { counter.incrementCount(tokenizer.nextToken(), 1.0); } } if (counter.totalCount() <= 0.0) { log.warn("Unable to calculate probability; nothing found"); return; } //laplace smoothing counter.incrementAll(counter.keySet(), 1.0); Set<String> remove = new HashSet<>(); for (String key : counter.keySet()) if (key.length() < 2 || key.matches("[a-z]+")) remove.add(key); for (String key : remove) counter.removeKey(key); counter.pruneKeysBelowThreshold(4.0); final double totalCount = counter.totalCount(); Collections.sort(this, new Comparator<List<String>>() { @Override public int compare(List<String> o1, List<String> o2) { double c1 = sumOverTokens(counter, o1.get(column), totalCount); double c2 = sumOverTokens(counter, o2.get(column), totalCount); return Double.compare(c1, c2); } }); } /* Return the log sum of the column relative to the word frequencies (equivalent to the probability in log space */ private double sumOverTokens(Counter<String> counter, String column, double totalCount) { StringTokenizer tokenizer = new StringTokenizer(column); double count = 0; while (tokenizer.hasMoreTokens()) count += Math.log(counter.getCount(column) / totalCount); return count; } public StringCluster clusterColumn(int column) { return new StringCluster(getColumn(column)); } public void dedupeByClusterAll() { for (int i = 0; i < size(); i++) dedupeByCluster(i); } /** * Deduplicate based on the column clustering signature * @param column */ public void dedupeByCluster(int column) { StringCluster cluster = clusterColumn(column); System.out.println(cluster.get("family mcdonalds restaurant")); System.out.println(cluster.get("family mcdonalds restaurants")); List<Map<String, Integer>> list2 = cluster.getClusters(); for (int i = 0; i < list2.size(); i++) { if (list2.get(i).size() > 1) { System.out.println(list2.get(i)); } } FingerPrintKeyer keyer = new FingerPrintKeyer(); Set<Integer> alreadyDeDupped = new HashSet<>(); for (int i = 0; i < size(); i++) { String key = keyer.key(get(i).get(column)); Map<String, Integer> map = cluster.get(key); if (map != null && map.size() > 1) { List<Integer> list = filterRowsByColumn(column, map.keySet()); //deduplication to do if (list.size() > 1) modifyRows(alreadyDeDupped, column, list, map); } } } /** * Cleans up the rows specified that haven't already been deduplified * @param alreadyDeDupped the already dedupped rows * @param column the column to homogenize * @param rows the rows to preProcess * @param cluster the cluster of values */ private void modifyRows(Set<Integer> alreadyDeDupped, Integer column, List<Integer> rows, Map<String, Integer> cluster) { String chosenKey = null; Integer max = null; for (Map.Entry<String, Integer> entry : cluster.entrySet()) { String key = entry.getKey(); int value = entry.getValue(); StringTokenizer val = new StringTokenizer(key); List<String> list = new ArrayList<>(); boolean allLower = true; outer: while (val.hasMoreTokens()) { String token = val.nextToken(); //weird capitalization if (token.length() >= 3 && token.matches("[A-Z]+")) { continue outer; } list.add(token); } for (String s : list) { allLower = allLower && s.matches("[a-z]+"); } if (allLower) { continue; } //not a proper name if (list.get(list.size() - 1).toLowerCase().equals("the")) { continue; } //first selection that's valid or count is higher if (max == null || (!allLower && value > max)) { max = value; chosenKey = key; } } //wtf is wrong with you people? if (chosenKey == null) { //getFromOrigin the max value of the cluster String max2 = maximalValue(cluster); StringTokenizer val = new StringTokenizer(max2); List<String> list = new ArrayList<>(); while (val.hasMoreTokens()) { String token = val.nextToken(); //weird capitalization if (token.length() >= 3 && token.matches("[A-Z]+")) { token = token.charAt(0) + token.substring(1).toLowerCase(); } list.add(token); } boolean allLower = true; for (String s : list) allLower = allLower && s.matches("[a-z]+"); if (list.get(list.size() - 1).toLowerCase().equals("the")) { max2 = max2.replaceAll("^[Tt]he", ""); } if (allLower) max2 = StringUtils.capitalize(max2); chosenKey = max2; } for (Integer i2 : rows) { //row already processed if (!alreadyDeDupped.contains(i2)) { disambiguateRow(i2, column, chosenKey); } } } private String maximalValue(Map<String, Integer> map) { Counter<String> counter = new Counter<>(); for (Map.Entry<String, Integer> entry : map.entrySet()) { counter.incrementCount(entry.getKey(), entry.getValue()); } return counter.argMax(); } private void disambiguateRow(Integer row, Integer column, String chosenValue) { System.out.println("SETTING " + row + " column " + column + " to " + chosenValue); get(row).set(column, chosenValue); } public List<Integer> filterRowsByColumn(int column, Collection<String> values) { List<Integer> list = new ArrayList<>(); for (int i = 0; i < size(); i++) { if (values.contains(get(i).get(column))) list.add(i); } return list; } public void sortBy(final int column) { Collections.sort(this, new Comparator<List<String>>() { @Override public int compare(List<String> o1, List<String> o2) { return o1.get(column).compareTo(o2.get(column)); } }); } public List<String> toLines() { List<String> lines = new ArrayList<>(); for (List<String> list : this) { StringBuilder sb = new StringBuilder(); for (String s : list) { sb.append(s.replaceAll(sep, " ")); sb.append(sep); } lines.add(sb.toString().substring(0, sb.lastIndexOf(sep))); } return lines; } public void swap(int column1, int column2) { List<String> col1 = getColumn(column1); List<String> col2 = getColumn(column2); for (int i = 0; i < size(); i++) { get(i).set(column1, col2.get(i)); get(i).set(column2, col1.get(i)); } } public void merge(int column1, int column2) { checkInvalidColumn(column1); checkInvalidColumn(column2); if (column1 != column2) for (List<String> list : this) { StringBuilder sb = new StringBuilder(); sb.append(list.get(column1)); sb.append(list.get(column2)); list.set(Math.min(column1, column2), sb.toString().replaceAll("\"", "").replace(sep, " ")); list.remove(Math.max(column1, column2)); } numColumns--; } public StringGrid getAllWithSimilarity(double threshold, int firstColumn, int secondColumn) { for (int column : new int[] { firstColumn, secondColumn }) checkInvalidColumn(column); StringGrid grid = new StringGrid(sep, numColumns); for (List<String> list : this) { double sim = MathUtils.stringSimilarity(list.get(firstColumn), list.get(secondColumn)); if (sim >= threshold) grid.addRow(list); } return grid; } public void writeLinesTo(String path) throws IOException { FileUtils.writeLines(new File(path), toLines()); } public void fillDown(String value, int column) { checkInvalidColumn(column); for (List<String> list : this) list.set(column, value); } public StringGrid select(int column, String value) { StringGrid grid = new StringGrid(sep, numColumns); for (int i = 0; i < size(); i++) { List<String> row = get(i); if (row.get(column).equals(value)) { grid.addRow(row); } } return grid; } public void split(int column, String sepBy) { List<String> col = getColumn(column); int validate = -1; Set<String> remove = new HashSet<>(); for (int i = 0; i < col.size(); i++) { String s = col.get(i); String[] split2 = StringUtils.splitOnCharWithQuoting(s, sepBy.charAt(0), '"', '\\'); if (validate < 0) validate = split2.length; else if (validate != split2.length) { log.warn("Row " + get(i) + " will be invalid after split; removing"); remove.add(s); } } for (String s : remove) { StringGrid grid = select(column, s); removeAll(grid); } Map<Integer, List<String>> replace = new HashMap<>(); for (int i = 0; i < size(); i++) { List<String> list = get(i); List<String> newList = new ArrayList<>(); String split = list.get(column); String[] split2 = StringUtils.splitOnCharWithQuoting(split, sepBy.charAt(0), '"', '\\'); //add right next to where column was split for (int j = 0; j < list.size(); j++) { if (j == column) for (String s : split2) newList.add(s); else newList.add(list.get(j)); } replace.put(i, newList); } //prevent concurrent modification for (Map.Entry<Integer, List<String>> entry : replace.entrySet()) { set(entry.getKey(), entry.getValue()); } } public void filterBySimilarity(double threshold, int firstColumn, int secondColumn) { for (int column : new int[] { firstColumn, secondColumn }) checkInvalidColumn(column); List<List<String>> remove = new ArrayList<>(); for (List<String> list : this) { double sim = MathUtils.stringSimilarity(list.get(firstColumn), list.get(secondColumn)); if (sim < threshold) remove.add(list); } removeAll(remove); } public void prependToEach(String prepend, int toColumn) { for (List<String> row : this) { String currVal = row.get(toColumn); row.set(toColumn, prepend + currVal); } } public void appendToEach(String append, int toColumn) { for (List<String> row : this) { String currVal = row.get(toColumn); row.set(toColumn, currVal + append); } } public void addColumn(List<String> column) { if (column.size() != this.size()) throw new IllegalArgumentException("Unable to add column; not enough rows"); for (int i = 0; i < size(); i++) { get(i).add(column.get(i)); } } /** * Combine the column based on a template and a number of template variable * columns. Note that this will also collapse the columns specified (removing them) * * @param templateColumn the column with the template ( uses printf style templating) * @param paramColumns the columns with template variables */ public void combineColumns(int templateColumn, Integer[] paramColumns) { for (List<String> list : this) { List<String> format = new ArrayList<>(); for (int j : paramColumns) format.add(list.get(j)); list.set(templateColumn, String.format(list.get(templateColumn), (Object[]) format.toArray(new String[] {}))); //collapse columns list.removeAll(format); } } /** * Combine the column based on a template and a number of template variable * columns. Note that this will also collapse the columns specified (removing them) * * @param templateColumn the column with the template ( uses printf style templating) * @param paramColumns the columns with template variables */ public void combineColumns(int templateColumn, int[] paramColumns) { for (List<String> list : this) { List<String> format = new ArrayList<>(); for (int j : paramColumns) format.add(list.get(j)); list.set(templateColumn, String.format(list.get(templateColumn), (Object[]) format.toArray(new String[] {}))); //collapse columns list.removeAll(format); } } public void addRow(List<String> row) { if (row.isEmpty()) { log.warn("Unable to add empty row"); } else if (!isEmpty() && row.size() != get(0).size()) { log.warn("Unable to add row; not the same number of columns"); } else add(row); } public Map<String, List<List<String>>> mapByPrimaryKey(int columnKey) { Map<String, List<List<String>>> map = new HashMap<>(); for (List<String> line : this) { String val = line.get(columnKey); List<List<String>> get = map.get(val); if (get == null) { get = new ArrayList<>(); map.put(val, get); } get.add(new ArrayList<>(Arrays.asList(sep))); } return map; } public List<String> getRow(int row) { checkInvalidRow(row); return new ArrayList<>(get(row)); } public List<String> getColumn(int column) { checkInvalidColumn(column); List<String> ret = new ArrayList<>(); for (List<String> list : this) { ret.add(list.get(column)); } return ret; } private void checkInvalidRow(int row) { if (row < 0 || row >= size()) throw new IllegalArgumentException("Row does not exist"); } private void checkInvalidColumn(int column) { if (column < 0 || column >= numColumns) throw new IllegalArgumentException("Invalid column " + column); } public StringGrid getRowsWithDuplicateValuesInColumn(int column) { checkInvalidColumn(column); StringGrid grid = new StringGrid(sep, numColumns); List<String> columns = getColumn(column); Counter<String> counter = new Counter<>(); for (String val : columns) counter.incrementCount(val, 1.0); counter.pruneKeysBelowThreshold(2.0); Set<String> keys = counter.keySet(); for (List<String> row : this) { for (String key : keys) if (row.get(column).equals(key)) grid.addRow(row); } return grid; } public StringGrid getRowWithOnlyOneOccurrence(int column) { checkInvalidColumn(column); StringGrid grid = new StringGrid(sep, numColumns); List<String> columns = getColumn(column); Counter<String> counter = new Counter<>(); for (String val : columns) counter.incrementCount(val, 1.0); Set<String> keys = new HashSet<>(counter.keySet()); for (String key : keys) { if (counter.getCount(key) > 1) { counter.removeKey(key); } } for (List<String> row : this) { for (String key : keys) if (row.get(column).equals(key)) grid.addRow(row); } return grid; } public StringGrid getUniqueRows() { StringGrid ret = new StringGrid(this); ret.stripDuplicateRows(); return ret; } public void stripDuplicateRows() { Set<List<String>> set = new HashSet<>(this); clear(); addAll(set); } }