Java tutorial
/*
 * Copyright 2014 Steve Ash
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.github.steveash.jg2p.align;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Sets;
import com.google.common.collect.Table;

import com.carrotsearch.hppc.ObjectDoubleMap;
import com.carrotsearch.hppc.ObjectDoubleOpenHashMap;
import com.github.steveash.jg2p.seq.StringListToTokenSequence;

import org.apache.commons.lang3.tuple.Pair;

import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;

import static com.github.steveash.jg2p.util.Assert.assertProb;

/**
 * Table of probabilities from Xi to Yi.
 *
 * <p>Values are stored in a Guava {@link Table} keyed by (x, y). A pair that was never stored is
 * reported by {@link #prob(String, String)} as {@code 0}. Iterating this object walks the stored
 * cells of the backing table.
 *
 * @author Steve Ash
 */
public class ProbTable implements Iterable<Table.Cell<String, String, Double>>, Externalizable {

  private static final long serialVersionUID = -8001165446102770332L;

  public static final double minLogProb = -1e12;

  /**
   * Returns the set of all (x, y) pairs whose stored value is strictly positive in {@code a},
   * unioned with all such pairs in {@code b}.
   *
   * @param a first table to collect pairs from
   * @param b second table to collect pairs from
   * @return a new mutable set holding every (x, y) pair that has a positive value in either table
   */
  public static Set<Pair<String, String>> unionOfAllCells(ProbTable a, ProbTable b) {
    // The expected size is only a capacity hint; the set grows if the union is larger.
    int sizeHint = Math.max(a.xyProb.size(), b.xyProb.size());
    Set<Pair<String, String>> xys = Sets.newHashSetWithExpectedSize(sizeHint);
    addAllPresent(a, xys);
    addAllPresent(b, xys);
    return xys;
  }

  /**
   * Adds to {@code output} the (row, column) key pair of every cell in {@code tbl} whose value is
   * non-null and greater than zero.
   */
  private static void addAllPresent(ProbTable tbl, Set<Pair<String, String>> output) {
    for (Table.Cell<String, String, Double> cell : tbl) {
      Double value = cell.getValue();
      if (value != null && value > 0) {
        output.add(Pair.of(cell.getRowKey(), cell.getColumnKey()));
      }
    }
  }

  /**
   * Iterates over the cells of the backing table.
   *
   * @return an iterator over the backing table's cell set
   */
  @Override
  public Iterator<Table.Cell<String, String, Double>> iterator() {
    return xyProb.cellSet().iterator();
  }

  /**
   * Returns the row of the backing table for {@code x}: a map from each y to the value stored for
   * (x, y). The map is whatever {@link Table#row(Object)} returns for the backing table; it is not
   * copied here.
   *
   * @param x the row key
   * @return the y-to-value map for {@code x}
   */
  public Map<String, Double> getYProbForX(String x) {
    return xyProb.row(x);
  }

  /**
   * Per-x and per-y sums of a {@link ProbTable}, together with the sum over all of its cells.
   */
  public static class Marginals {

    private final ObjectDoubleMap<String> xMarginals;
    private final ObjectDoubleMap<String> yMarginals;
    private final double sumJointMass; // sum of all probability mass across all X x Y joint distrib

    Marginals(ObjectDoubleMap<String> xMarginals, ObjectDoubleMap<String> yMarginals,
              double sumJointMass) {
      this.xMarginals = xMarginals;
      this.yMarginals = yMarginals;
      this.sumJointMass = sumJointMass;
    }

    /**
     * Returns the marginal recorded for {@code y}, or {@code 0} when {@code y} has no entry.
     *
     * <p>The default for a missing key is {@code 0}, matching {@link #probX(String)} and
     * {@link ProbTable#prob(String, String)}. A negative default cannot be a probability and
     * would be handed straight to {@code assertProb}.
     *
     * @param y the y symbol to look up
     * @return the value for {@code y} after passing through {@code assertProb}
     */
    public double probY(String y) {
      return assertProb(yMarginals.getOrDefault(y, 0.0));
    }

    /**
     * Returns the marginal recorded for {@code x}, or {@code 0} when {@code x} has no entry.
     *
     * @param x the x symbol to look up
     * @return the value for {@code x} after passing through {@code assertProb}
     */
    public double probX(String x) {
      return assertProb(xMarginals.getOrDefault(x, 0.0));
    }

    /**
     * @return the number of distinct y keys that have a marginal entry
     */
    public int countY() {
      return yMarginals.size();
    }

    /**
     * @return the number of distinct x keys that have a marginal entry
     */
    public int countX() {
      return xMarginals.size();
    }

    /**
     * @return the sum of the values of every cell of the table these marginals were computed from
     */
    public double sumOfAllJointProbabilities() {
      return sumJointMass;
    }
  }

  // Not final only because readExternal has to replace it during deserialization.
  private /*final*/ Table<String, String, Double> xyProb = HashBasedTable.create();

  /** Creates an empty table. A public no-arg constructor is required by {@link Externalizable}. */
  public ProbTable() {
  }

  /**
   * Returns the value stored for (x, y).
   *
   * @param x the row key
   * @param y the column key
   * @return the stored value, or {@code 0} if nothing is stored for the pair
   */
  public double prob(String x, String y) {
    Double maybe = xyProb.get(x, y);
    if (maybe == null) {
      return 0;
    }
    return maybe;
  }

  /** Removes every entry from the table. */
  public void clear() {
    xyProb.clear();
  }

  /**
   * Stores {@code value} for (x, y), replacing any previous value.
   *
   * @param x the row key
   * @param y the column key
   * @param value the value to store
   */
  public void setProb(String x, String y, double value) {
    xyProb.put(x, y, value);
  }

  /**
   * Adds {@code valueToAdd} to the value stored for (x, y), treating a missing entry as
   * {@code 0}.
   *
   * @param x the row key
   * @param y the column key
   * @param valueToAdd the amount to add to the current value
   */
  public void addProb(String x, String y, double valueToAdd) {
    Double maybe = xyProb.get(x, y);
    if (maybe == null) {
      maybe = 0.0;
    }
    xyProb.put(x, y, maybe + valueToAdd);
  }

  /**
   * @return the number of (x, y) cells currently stored
   */
  public long entryCount() {
    return xyProb.size();
  }

  /**
   * Sums the stored values per x, per y, and over the whole table in a single pass.
   *
   * @return the marginals of the current table contents
   */
  public Marginals calculateMarginals() {
    ObjectDoubleOpenHashMap<String> x = ObjectDoubleOpenHashMap.newInstance();
    ObjectDoubleOpenHashMap<String> y = ObjectDoubleOpenHashMap.newInstance();
    double sum = 0;
    for (Table.Cell<String, String, Double> cell : xyProb.cellSet()) {
      double value = cell.getValue(); // unbox once per cell
      // putOrAdd stores the value for a new key and adds it to the running total otherwise
      x.putOrAdd(cell.getRowKey(), value, value);
      y.putOrAdd(cell.getColumnKey(), value, value);
      sum += value;
    }
    return new Marginals(x, y, sum);
  }

  /**
   * Builds a new table in which every cell holds this table's value divided by the sum of all
   * values in this table. This table is not modified.
   *
   * <p>No special handling is applied when the total is zero: the floating-point division then
   * yields {@code NaN} or an infinity for each stored cell.
   *
   * @return a new table holding the normalized values
   */
  public ProbTable makeNormalizedCopy() {
    ProbTable result = new ProbTable();
    Marginals marginals = this.calculateMarginals();
    double sum = marginals.sumOfAllJointProbabilities();
    for (Table.Cell<String, String, Double> cell : this) {
      double normalValue = cell.getValue() / sum;
      result.setProb(cell.getRowKey(), cell.getColumnKey(), normalValue);
    }
    return result;
  }

  /**
   * @return the backing table's set of row (x) keys
   */
  public Set<String> xRows() {
    return xyProb.rowKeySet();
  }

  /**
   * @return the backing table's set of column (y) keys
   */
  public Set<String> yCols() {
    return xyProb.columnKeySet();
  }

  /**
   * Writes the backing table to {@code out} as a single object.
   *
   * @throws IOException if writing to {@code out} fails
   */
  @Override
  public void writeExternal(ObjectOutput out) throws IOException {
    out.writeObject(this.xyProb);
  }

  /**
   * Replaces the backing table with the object read from {@code in}; the counterpart of
   * {@link #writeExternal(ObjectOutput)}.
   *
   * @throws IOException if reading from {@code in} fails
   * @throws ClassNotFoundException if the class of the serialized object cannot be found
   */
  @Override
  public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
    // The generic parameters cannot be checked at runtime; writeExternal only ever writes xyProb.
    @SuppressWarnings("unchecked")
    Table<String, String, Double> restored = (Table<String, String, Double>) in.readObject();
    this.xyProb = restored;
  }
}