Java tutorial
/** * Copyright 2013 Brigham Young University * * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except * in compliance with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software distributed under the License * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express * or implied. See the License for the specific language governing permissions and limitations under * the License. */ package edu.byu.nlp.crowdsourcing.models.gibbs; import java.util.Arrays; import java.util.Iterator; import java.util.Map; import org.apache.commons.math3.random.RandomGenerator; import org.fest.util.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.google.common.base.Preconditions; import edu.byu.nlp.classify.eval.BasicPrediction; import edu.byu.nlp.classify.eval.Prediction; import edu.byu.nlp.crowdsourcing.CrowdsourcingUtils; import edu.byu.nlp.crowdsourcing.MultiAnnState; import edu.byu.nlp.crowdsourcing.MultiAnnState.BasicMultiAnnState; import edu.byu.nlp.crowdsourcing.PriorSpecification; import edu.byu.nlp.crowdsourcing.TrainableMultiAnnModel; import edu.byu.nlp.data.types.Dataset; import edu.byu.nlp.data.types.DatasetInstance; import edu.byu.nlp.data.types.SparseFeatureVector; import edu.byu.nlp.data.types.SparseFeatureVector.EntryVisitor; import edu.byu.nlp.dataset.Datasets; import edu.byu.nlp.dataset.SparseFeatureVectors; import edu.byu.nlp.math.GammaFunctions; import edu.byu.nlp.math.Math2; import edu.byu.nlp.math.SparseRealMatrices; import edu.byu.nlp.util.DoubleArrays; import edu.byu.nlp.util.IntArrays; import edu.byu.nlp.util.Matrices; /** * @author rah67 * @author plf1 * */ public class BlockCollapsedMultiAnnModelMath { private static final Logger logger = LoggerFactory.getLogger(TrainableMultiAnnModel.class); /* * public int docIndexFor(FlatInstance<SparseFeatureVector, Integer> instance) { * Iterable<Enumeration<FlatInstance<SparseFeatureVector, Integer>>> enumeration = * Iterables2.enumerate(data.allInstances()); for (Enumeration<Instance<Integer, * SparseFeatureVector>> e : enumeration) { if (e.getElement() == instance) { return e.getIndex(); * } } return -1; } * * public void updateA(FlatInstance<SparseFeatureVector, Integer> instance, int annotator, int label) * { int docIndex = docIndexFor(instance); if (docIndex < 0) { throw new * IllegalArgumentException("Couldn't find instance"); } ++a[docIndex][annotator][label]; } */ static boolean hasCorrectCounts(int[] y, int[] m, double[] logCountOfY, double[][][] countOfJYAndA, int[][][] a, double[][] numAnnsPerJAndY, int[][] docJCount, double[][] countOfMAndX, double[] numFeaturesPerM, double[] docSize, double[][] logCountOfYAndM, double[] logSumCountOfYAndM, int numAnnotators, int numLabels, int numFeatures, Dataset data, PriorSpecification priors) { if (data == null || priors == null) { logger.warn("Ignoring the hasCorrectCounts assertion because priors and/or data have not been set"); return true; } double[] logCountOfYSanity = DoubleArrays.of(priors.getBTheta(), numLabels); for (int i = 0; i < y.length; i++) { ++logCountOfYSanity[y[i]]; } DoubleArrays.logToSelf(logCountOfYSanity); if (!DoubleArrays.equals(logCountOfYSanity, logCountOfY, 1e-8)) { return false; } { // keeps the int i out of the scope of other loops double[][] countOfMAndXSanity = Matrices.of(priors.getBPhi(), numLabels, numFeatures); int i = 0; for (DatasetInstance instance : data) { for (SparseFeatureVector.Entry e : instance.asFeatureVector().sparseEntries()) { countOfMAndXSanity[m[i]][e.getIndex()] += e.getValue(); } ++i; } if (!Matrices.equals(countOfMAndXSanity, countOfMAndX, 1e-8)) { return false; } } double[][] logCountOfYAndMSanity = new double[numLabels][numLabels]; CrowdsourcingUtils.initializeConfusionMatrixWithPrior(logCountOfYAndMSanity, priors.getBMu(), priors.getCMu()); for (int i = 0; i < y.length; i++) { ++logCountOfYAndMSanity[y[i]][m[i]]; } Matrices.logToSelf(logCountOfYAndMSanity); if (!Matrices.equals(logCountOfYAndMSanity, logCountOfYAndM, 1e-8)) { return false; } // Assumes a[][][] is correct for (int j = 0; j < countOfJYAndA.length; j++) { double[][] countOfJYAndASanity = new double[numLabels][numLabels]; CrowdsourcingUtils.initializeConfusionMatrixWithPrior(countOfJYAndASanity, priors.getBGamma(j), priors.getCGamma()); for (int i = 0; i < y.length; i++) { for (int k = 0; k < numLabels; k++) { countOfJYAndASanity[y[i]][k] += a[i][j][k]; } } if (!Matrices.equals(countOfJYAndASanity, countOfJYAndA[j], 1e-8)) { return false; } } return true; } @VisibleForTesting static double getMeanSquaredDistanceFrom01(double[][] mat) { double mse = 0; for (int r = 0; r < mat.length; r++) { double[] row = mat[r]; for (int c = 0; c < row.length; c++) { double val = mat[r][c]; assert 0 <= val && val <= 1; double dist = (val < .5) ? val : 1 - val; mse += dist * dist; } } return mse; } public static double[][] confusionMatrix(Boolean labeled, int[] gold, int[] guesses, int numLabels, Dataset data) { Preconditions.checkArgument(!(data == null && labeled != null)); Iterator<DatasetInstance> itr = (data == null) ? null : data.iterator(); double[][] confusion = new double[numLabels][numLabels]; for (int i = 0; i < guesses.length; i++) { // ignore truly unlabeled data (label=-1) if (gold[i] == -1) { continue; } DatasetInstance inst = (itr == null) ? null : itr.next(); // "unlabeled" (labels exist, but are hidden) boolean hasAnnotations = SparseRealMatrices.sum(inst.getAnnotations().getLabelAnnotations()) != 0; if (labeled == null || (!labeled && !hasAnnotations)) { ++confusion[gold[i]][guesses[i]]; } // labeled if (labeled == null || (labeled && hasAnnotations)) { ++confusion[gold[i]][guesses[i]]; } } return confusion; } public static enum DiagonalizationMethod { NONE, GOLD, AVG_GAMMA, MAX_GAMMA, RAND } /** * @param gold (optional) If the GOLD diagonalization method is to * be used, gold standard labels must also be passed in. */ public static MultiAnnState fixLabelSwitching(MultiAnnState sample, DiagonalizationMethod diagonalizationMethod, int goldInstancesForDiagonalization, boolean diagonalizationWithFullConfusionMatrix, RandomGenerator rnd) { // public static MultiAnnSample fixLabelSwitching(MultiAnnSample sample, int[] gold, boolean labelSwitchingCheatUsesFullConfusionMatrix, Dataset data) { // manually re-order the labels according to y and then // m so that mu and alpha_j line up with their greatest entries // along the diagonal as much as possible. // This helps alleviate the problem of label switching. // Note that we are not changing the meaning of labels // globally (in the data elicited from annotators); // merely ensuring that the features-only-ML // will be most likely to assign a Y=0 assignment // the label 0; a Y=1 assignment the label 1, and // so forth (rather than learning to systematically // map Y=0 to 2, for example, which would be a setting // with just as much probability as the Y=0 to 0 setting. // It's important to NOT use the counts (e.g., logCountOfYAndM) // We need to look at actual normalized accuracies (e.g., mu). int[] y = sample.getY().clone(); int[] m = sample.getM().clone(); double[][] mu = Matrices.clone(sample.getMu()); double[][] meanMu = Matrices.clone(sample.getMeanMu()); double[][][] alpha = Matrices.clone(sample.getAlpha()); double[][][] meanAlpha = Matrices.clone(sample.getMeanAlpha()); double[] theta = sample.getTheta().clone(); double[] meanTheta = sample.getMeanTheta().clone(); double[][] logPhi = Matrices.clone(sample.getLogPhi()); double[][] meanLogPhi = Matrices.clone(sample.getMeanLogPhi()); // -------------- Fix Y ----------------------- // int[] yMap; int[] gold; switch (diagonalizationMethod) { case NONE: logger.info("Not Diagonalizing"); // diagonal mapping (no change) yMap = IntArrays.sequence(0, mu.length); break; case RAND: logger.info("Diagonalizing randomly"); // randomly shuffled mapping yMap = IntArrays.shuffled(IntArrays.sequence(0, mu.length), rnd); break; case GOLD: logger.info("Diagonalizing based on gold 'heldout data'"); // create a confusion matrix by comparing gold labels with model predictions (gold labels are constructed to match the model ordering) Boolean useLabeledConfusionMatrix = diagonalizationWithFullConfusionMatrix ? null : true; gold = Datasets.concealedLabels(sample.getData(), sample.getInstanceIndices()); int numGoldInstances = goldInstancesForDiagonalization == -1 ? gold.length : goldInstancesForDiagonalization; gold = Arrays.copyOfRange(gold, 0, numGoldInstances); int[] guesses = Arrays.copyOfRange(sample.getY(), 0, numGoldInstances); double[][] confusions = confusionMatrix(useLabeledConfusionMatrix, gold, guesses, sample.getNumLabels(), sample.getData()); // in a CONFUSION matrix, columns correspond to the latent variable y. So // permute columns to find a good diagonalization yMap = Matrices.getColReorderingForStrongDiagonal(confusions); break; case AVG_GAMMA: // in a gamma matrix, rows correspond to the latent variable y, so permute rows // to find a good diagonalization logger.info("Diagonalizing based on average alpha"); double[][] cumulativeAlphaMean = new double[sample.getNumLabels()][sample.getNumLabels()]; for (int j = 0; j < sample.getNumAnnotators(); j++) { double[][] alphaMean = meanAlpha[j]; Matrices.addToSelf(cumulativeAlphaMean, alphaMean); } yMap = Matrices.getRowReorderingForStrongDiagonal(cumulativeAlphaMean); break; case MAX_GAMMA: logger.info("Diagonalizing based on most confident alpha"); // (pfelt) Find the most definitive alpha matrix // (the one with entries that diverge least from 0 and 1) // We'll map that to be diagonal and then apply its mapping // to all of the other alphas, since alpha matrices // are constrained by the data to be coherent. double[][] bestAlphaMean = null; double min = Double.POSITIVE_INFINITY; for (int j = 0; j < sample.getNumAnnotators(); j++) { double[][] alphaMean = meanAlpha[j]; double error = getMeanSquaredDistanceFrom01(alphaMean); if (error < min) { min = error; bestAlphaMean = alphaMean; } } yMap = Matrices.getNormalizedRowReorderingForStrongDiagonal(bestAlphaMean); break; default: throw new IllegalArgumentException( "unknown diagonalization method: " + diagonalizationMethod.toString()); } logger.info("Y-mapping=" + IntArrays.toString(yMap)); // fix alpha for (int j = 0; j < sample.getNumAnnotators(); j++) { Matrices.reorderRowsToSelf(yMap, alpha[j]); Matrices.reorderRowsToSelf(yMap, meanAlpha[j]); } // fix y for (int i = 0; i < y.length; i++) { y[i] = yMap[y[i]]; } // fix theta Matrices.reorderElementsToSelf(yMap, theta); Matrices.reorderElementsToSelf(yMap, meanTheta); // fix mu Matrices.reorderRowsToSelf(yMap, mu); Matrices.reorderRowsToSelf(yMap, meanMu); // (pfelt) we don't need to update cached values anymore since we're // operating on a sample and not in the context of a model being sampled // // fix logSumCountOfYAndM // Matrices.reorderElementsToSelf(yMap, logSumCountOfYAndM); // // fix numAnnsPerJAndY // Matrices.reorderColsToSelf(yMap, numAnnsPerJAndY); // -------------- Fix M ----------------------- // // (pfelt) We used to sample from mu (by calling mu()) // to get a mu setting. I've changed this to use the params // of mu for two reasons: // 1) a small performance savings // 2) it's easier to test int[] mMap; try { mMap = Matrices.getColReorderingForStrongDiagonal(meanMu); } catch (IllegalArgumentException e) { mMap = new int[meanMu.length]; for (int i = 0; i < mMap.length; i++) { mMap[i] = i; } logger.warn("unable to diagonalize m, returning the identity mapping. " + "If this is itemresp or momresp, then this is fine. " + "If this is multiann, then there is a serious problem."); } // fix mu Matrices.reorderColsToSelf(mMap, mu); Matrices.reorderColsToSelf(mMap, meanMu); // fix m for (int i = 0; i < m.length; i++) { m[i] = mMap[m[i]]; } // fix phi Matrices.reorderRowsToSelf(mMap, logPhi); Matrices.reorderRowsToSelf(mMap, meanLogPhi); // (pfelt) we don't need to update cached values anymore since we're // operating on a sample and not in the context of a model being sampled // // fix numFeaturesPerM // Matrices.reorderElementsToSelf(mMap, numFeaturesPerM); return new BasicMultiAnnState(y, m, theta, meanTheta, logPhi, meanLogPhi, mu, meanMu, alpha, meanAlpha, sample.getData(), sample.getInstanceIndices()); } static boolean isAnnotated(int docIndex, int[][][] a) { int[][] docAnnotations = a[docIndex]; for (int[] arr : docAnnotations) { if (IntArrays.sum(arr) > 0) { return true; } } return false; } static double[] computeMSums(SparseFeatureVector instance, double docSize, double[][] countOfMAndX, double[] numFeaturesPerM, int numLabels, double lambda) { assert instance.sum() * lambda - docSize < 1e-10; double[] mSums = new double[numLabels]; for (int d = 0; d < mSums.length; d++) { // mSums[d] = computeMSum(instance, docSize, countOfMAndX[d], numFeaturesPerM[d]); // (pfelt): normalize and scale current document counts double scaledDocSize = docSize; SparseFeatureVector scaledInstance = instance; if (lambda >= 0) { scaledDocSize = docSize * lambda; // double scale = (1/docSize)*lambda; // scaledDocSize = lambda; scaledInstance = scaledInstance.copy(); SparseFeatureVectors.multiplyToSelf(scaledInstance, lambda); } mSums[d] = computeMSum(scaledInstance, scaledDocSize, countOfMAndX[d], numFeaturesPerM[d]); } return mSums; } @VisibleForTesting static double computeMSum(SparseFeatureVector instance, double docSize, double[] countOfMAndX, double numFeaturesPerM) { // these assertions no long valid with fractional counts // assert Math.round(instance.sum()) == docSize; // assert Math2.doubleEquals(DoubleArrays.sum(countOfMAndX), numFeaturesPerM, 1e-10); // return CollapsedParameters.sumLogOfRisingFactorial(instance, countOfMAndX) // - GammaFunctions.logRisingFactorial(numFeaturesPerM, docSize); return sumLogOfRatioOfGammas(instance, countOfMAndX) - GammaFunctions.logRatioOfGammasByDifference(numFeaturesPerM, docSize); } @VisibleForTesting static double[][] computeYMSums(int docIndex, double[] logSumCountOfYAndM, double[][] logCountOfYAndM, int numLabels, double lambda) { double[][] sums = new double[numLabels][numLabels]; for (int c = 0; c < numLabels; c++) { // assert BlockCollapsedMultiAnnModel.isValidLogSumCountOfYAndM(c, logCountOfYAndM, logSumCountOfYAndM); double logSumCountOfY = GammaFunctions.logRatioOfGammasByDifference(Math.exp(logSumCountOfYAndM[c]), lambda); for (int d = 0; d < numLabels; d++) { sums[c][d] = computeYMSum(logCountOfYAndM[c][d], lambda); sums[c][d] -= logSumCountOfY; } } return sums; } static double computeYMSum(double logCountOfYAndM, double lambda) { return GammaFunctions.logRatioOfGammasByDifference(Math.exp(logCountOfYAndM), lambda); } static double[] computeYSums(int docIndex, Map<Integer, Integer> instanceLabels, double[] logCountOfY, double[][][] countOfJYAndA, double[][] numAnnsPerJAndY, int[][][] a, int[][] docJCount, int numAnnotators, double lambda) { double[] ySums = DoubleArrays.of(0, logCountOfY.length); Integer label = instanceLabels != null && instanceLabels.containsKey(docIndex) ? instanceLabels.get(docIndex) : null; // double[] ySums = logCountOfY.clone(); for (int c = 0; c < ySums.length; c++) { // labeled item (uses delta function prob) if (label != null) { ySums[c] = c == label ? 0 : Double.NEGATIVE_INFINITY; // prob = 0/1 } // no label (just theta prior and annotations) else { // theta ySums[c] = GammaFunctions.logRatioOfGammasByDifference(Math.exp(logCountOfY[c]), lambda); // gamma for (int j = 0; j < numAnnotators; j++) { ySums[c] += computeYSum(countOfJYAndA[j][c], numAnnsPerJAndY[j][c], a[docIndex][j], docJCount[docIndex][j]); } } } return ySums; } // NOTE: EXCLUDES log(b_j + n_c) (see computeYSums) public static double computeYSum(double[] countOfJYAndA, double numAnnsPerJAndY, int[] a_ij, int docJCount) { assert Math2.doubleEquals(numAnnsPerJAndY, DoubleArrays.sum(countOfJYAndA), 1e-10); assert docJCount == IntArrays.sum(a_ij); double ySum = -GammaFunctions.logRisingFactorial(numAnnsPerJAndY, docJCount); for (int k = 0; k < a_ij.length; k++) { if (a_ij[k] > 0) { ySum += GammaFunctions.logRisingFactorial(countOfJYAndA[k], a_ij[k]); } } return ySum; } // TODO: alter this function to return ranked predictions using sampling distributions. public static Iterable<Prediction> predictions(final Dataset data, final int y[], final Map<String, Integer> instanceMap) { return new Iterable<Prediction>() { @Override public Iterator<Prediction> iterator() { return new Iterator<Prediction>() { // only predict on unlabeled portion of dataset; the labeled portion is known private Iterator<DatasetInstance> it = data.iterator(); @Override public boolean hasNext() { return it.hasNext(); } @Override public Prediction next() { DatasetInstance instance = it.next(); // prediction based on most recent sample Integer prediction = null; if (instanceMap.containsKey(instance.getInfo().getRawSource())) { prediction = y[instanceMap.get(instance.getInfo().getRawSource())]; } return new BasicPrediction(prediction, instance); } @Override public void remove() { throw new UnsupportedOperationException(); } }; } }; } private static class SumLogOfRatioOfGammas implements EntryVisitor { private final double[] topicWordCounts; private double acc; public SumLogOfRatioOfGammas(double[] topicWordCounts) { this.topicWordCounts = topicWordCounts; this.acc = 0.0; } @Override public void visitEntry(int index, double value) { acc += GammaFunctions.logRatioOfGammas(topicWordCounts[index] + value, topicWordCounts[index]); } public double getSum() { return acc; } } public static double sumLogOfRatioOfGammas(SparseFeatureVector doc, double[] topicWordCounts) { SumLogOfRatioOfGammas visitor = new SumLogOfRatioOfGammas(topicWordCounts); doc.visitSparseEntries(visitor); return visitor.getSum(); } }