Java tutorial
/* * Module Name: hcstools * This module is a plugin for the KNIME platform <http://www.knime.org/> * * Copyright (c) 2011. * Max Planck Institute of Molecular Cell Biology and Genetics, Dresden * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as * published by the Free Software Foundation, either version 3 of the * License, or (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * Detailed terms and conditions are described in the license.txt. * also see <http://www.gnu.org/licenses/>. */ package de.mpicbg.knime.hcs.base.utils; import org.apache.commons.math.random.RandomData; import org.apache.commons.math.random.RandomDataImpl; import org.apache.commons.math.stat.descriptive.DescriptiveStatistics; /** * Author: Felix Meyenhofer * Date: 8/19/11 * Time: 1:35 PM * <p/> * Calss to calculate the mutual information between to random variables using the histogram based approach according to * "Moddemeijer R., A statistic to estimate the variance of the histogram based mutual information * estimator based on dependent pairs of observations , Signal Processing, 1999, vol. 75, nr. 1, pp. 51-63" * <p/> * TODO investigate if the bootstraping (which is not a nice solution) could be replaced with histogram normalization. */ public class MutualInformation { private Double logbase = 2.0; // logarithmic logbase private String method = "biased"; private boolean linkaxes = true; // calculate one single axe using the data for the combined data from x and y. private int Nx = 100; // Number of bins for vector x private int Ny = 100; // Number of bins for vector y private Double[] x; // vector containing samples of variable X private Double[] y; // vector containing samples of variable Y // Constructors public MutualInformation() { } public MutualInformation(String method, int n, double logbase) { this.set_binning(n); this.set_method(method); this.set_base(logbase); } public MutualInformation(Double[] x, Double[] y) { this.set_vectors(x, y); this.set_binning(); } public MutualInformation(Double[] x, Double[] y, int n) { this.set_vectors(x, y); this.set_binning(n); } public MutualInformation(Double[] x, Double[] y, int nx, int ny) { this.set_vectors(x, y); this.set_binning(nx, ny); } public MutualInformation(Double[] x, Double[] y, String method) { this(x, y); this.set_method(method); } // Setter public void set_base(Double b) { if (!(b == 2 || b == Math.exp(1))) System.out.println("The logbase is usually choosen to be 2 or e."); if (b < 0) throw new RuntimeException("The logbase has to be a positive Real number"); logbase = b; } public void set_binning() { int bins = (int) Math.ceil(Math.pow(Math.max(x.length, y.length), 1.0 / 3.0)); set_binning(bins); } public void set_binning(int n) { if (n < 1) throw new RuntimeException("The number of bins of " + n + " is too small." + ". Probably there are too few samples, or consider setting the number of bins in the sintanciation."); Nx = n; Ny = n; } public void set_binning(int n1, int n2) { if (n1 < 1 || n2 < 1) throw new RuntimeException("The number of bins of " + n1 + "or " + n2 + " is too small." + "Probably there are too few samples, or consider setting the number of bins in the sintanciation."); Nx = n1; Ny = n2; } public void set_xvector(Double[] v) { x = v; } public void set_yvector(Double[] v) { y = v; } public void set_vectors(Double[] v1, Double[] v2) { x = v1; y = v2; } public void set_method(String method) { this.method = method; } public void set_axeslinking(boolean flag) { this.linkaxes = flag; } // Getter public int[] get_binning() { return new int[] { this.Nx, this.Ny }; } public Double[] get_xvector() { return this.x; } public Double[] get_yvector() { return this.y; } public String get_method() { return this.method; } public double get_logbase() { return this.logbase; } // Methods public Double[] calculate() throws Exception { if (x.length < 10 || y.length < 10) throw new RuntimeException("Too few samples."); Double[] res; if (method.contentEquals("unbiased")) { res = unbiased(); } else if (method.contentEquals("biased")) { res = biased(); } else if (method.contentEquals("mms_estimate")) { res = mms_estimate(); } else { throw new RuntimeException("The method '" + method + "' is unknown."); } res = basetransform(res, logbase); return res; } // Private helper methods private Double[] unbiased() { Double[] values = biased(); values[0] = values[0] - values[2]; values[2] = 0.0; return values; } private Double[] mms_estimate() { Double[] values = biased(); values[0] = values[0] - values[2]; Double lambda = Math.pow(values[0], 2) / (Math.pow(values[0], 2) + Math.pow(values[1], 2)); values[2] = (1 - lambda) * values[0]; values[0] = lambda * values[0]; values[1] = lambda * values[1]; return values; } private Double[] biased() { double[][] H = histogram2(); // total-sum, row-sum and column-sum int r = H.length; int c = H[1].length; double[] Hx = new double[r]; double[] Hy = new double[c]; int count = 0; for (int i = 0; i < r; i++) { for (int j = 0; j < c; j++) { Hx[i] += H[i][j]; Hy[i] += H[j][i]; count += H[i][j]; } } // Calculate mutual information. Double mutualinfo = 0.0; Double sigma = 0.0; Double logf; for (int i = 0; i < r; i++) { for (int j = 0; j < c; j++) { logf = log(H[i][j], Hx[i], Hy[j]); mutualinfo += H[i][j] * logf; sigma += H[i][j] * Math.pow(logf, 2); } } mutualinfo /= count; sigma = Math.sqrt((sigma / count - Math.pow(mutualinfo, 2)) / (count - 1)); mutualinfo += Math.log(count); Double bias = (double) (r - 1) * (c - 1) / (2 * count); return new Double[] { mutualinfo, sigma, bias }; } private double[][] histogram2() { // Get min and max of the scale(s) Double[] mima1; Double[] mima2; if (linkaxes) { mima1 = minmax(x, y); mima2 = mima1; } else { mima1 = minmax(x); mima2 = minmax(y); } // Calculate upper and lower bounds Double de1 = (mima1[1] - mima1[0]) / (x.length - 1); Double lb1 = mima1[0] - de1 / 2; Double ub1 = mima1[1] + de1 / 2; Double ra1 = (ub1 - lb1); Double de2 = (mima2[1] - mima2[0]) / (y.length - 1); Double lb2 = mima2[0] - de2 / 2; Double ub2 = mima2[1] + de2 / 2; Double ra2 = (ub2 - lb2); // Bring the vectors to the same length. if (x.length < y.length) { System.out.println("Warning: the vector lenghts (currrent:" + x.length + "," + y.length + ") need to be equal. Bottstrapped x."); x = bootstrap(x, y.length); } else if (x.length > y.length) { System.out.println("Warning: the vector lenghts (currrent:" + x.length + "," + y.length + ") need to be equal. Bottstrapped y."); y = bootstrap(y, x.length); } // Correct the binning. if ((Nx >= x.length) || (Ny >= y.length)) { System.out.println("Binning exceeded vector length and was set to" + Nx + "."); set_binning(); } // Compute the histogram/probability double[][] prob = new double[Nx][Ny]; for (int i = 0; i < x.length; i++) { int ind1 = (int) Math.round((x[i] - lb1) / ra1 * Nx + 0.5); int ind2 = (int) Math.round((y[i] - lb2) / ra2 * Ny + 0.5); if ((1 <= ind1) & (ind1 <= Nx) & (1 <= ind2) & (ind2 <= Ny)) { prob[ind1 - 1][ind2 - 1] += 1; } } return prob; } private Double[] minmax(Double[] vect) { DescriptiveStatistics stats = new DescriptiveStatistics(); for (Double value : vect) { stats.addValue(value); } return new Double[] { stats.getMin(), stats.getMax() }; } private Double[] minmax(Double[] vect1, Double[] vect2) { DescriptiveStatistics stats = new DescriptiveStatistics(); for (Double value : vect1) { stats.addValue(value); } for (Double value : vect2) { stats.addValue(value); } return new Double[] { stats.getMin(), stats.getMax() }; } private Double log(Double hxy, Double hx, Double hy) { if ((hxy < 1e-6)) // || (hy < 1e-6) || (hxy < 1e6) ) return 0.0; else return Math.log(hxy / hx / hy); } private Double[] basetransform(Double[] v, Double b) { for (int i = 0; i < v.length; i++) { v[i] /= Math.log(b); } return v; } private Double[] bootstrap(Double[] v, int Nboot) { Double[] boot = new Double[Nboot]; int I; int maxI = v.length - 1; RandomData rand = new RandomDataImpl(); for (int r = 0; r < Nboot; ++r) { I = rand.nextInt(0, maxI); boot[r] = v[I]; } return boot; } // Testing public static void main(String[] args) throws Exception { Double[] a = new Double[] { 1.0, 2.0, 2.0, 2.0, 0.0, 0.0, 1.0, 0.0, 1.0, 2.0 }; Double[] b = new Double[] { 1.0, 2.0, 2.0, 2.0, 2.0, 1.0, 0.0, 2.0, 1.0, 0.0 }; MutualInformation mutinf = new MutualInformation(a, b, 3); Double[] res; res = mutinf.calculate(); System.err.println("mutual information (" + mutinf.method + ", log" + mutinf.logbase + "): " + res[0] + ", sigma: " + res[1] + ", bias: " + res[2]); mutinf.set_method("biased"); res = mutinf.calculate(); System.err.println("mutual information (" + mutinf.method + ", log" + mutinf.logbase + "): " + res[0] + ", sigma: " + res[1] + ", bias: " + res[2]); } }