edu.berkeley.compbio.ml.cluster.bayesian.LabelDecomposingBayesianClustering.java Source code

Java tutorial

Introduction

Here is the source code for edu.berkeley.compbio.ml.cluster.bayesian.LabelDecomposingBayesianClustering.java

Source

/*
 * Copyright (c) 2006-2013  David Soergel  <dev@davidsoergel.com>
 * Licensed under the Apache License, Version 2.0
 * http://www.apache.org/licenses/LICENSE-2.0
 */

package edu.berkeley.compbio.ml.cluster.bayesian;

import com.davidsoergel.dsutils.GenericFactory;
import com.davidsoergel.dsutils.GenericFactoryException;
import com.davidsoergel.stats.DissimilarityMeasure;
import com.davidsoergel.stats.DistributionException;
import com.davidsoergel.stats.Multinomial;
import com.google.common.collect.ImmutableMap;
import edu.berkeley.compbio.ml.cluster.AdditiveCentroidCluster;
import edu.berkeley.compbio.ml.cluster.AdditiveClusterable;
import edu.berkeley.compbio.ml.cluster.CentroidCluster;
import edu.berkeley.compbio.ml.cluster.Cluster;
import edu.berkeley.compbio.ml.cluster.ClusterMove;
import edu.berkeley.compbio.ml.cluster.ClusterRuntimeException;
import edu.berkeley.compbio.ml.cluster.ClusterableIterator;
import edu.berkeley.compbio.ml.cluster.ProhibitionModel;
import edu.berkeley.compbio.ml.cluster.kmeans.GrowableKmeansClustering;
import org.apache.commons.lang.NotImplementedException;
import org.apache.log4j.Logger;

import java.util.HashMap;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;

/**
 * Performs a simple unsupervised clustering on all the samples with a given label in an attempt to decompose the
 * label-level cluster into several smaller clusters.  This is an unusual case in that it both requires a prototype
 * factory and is sample-initialized.
 *
 * @author David Soergel
 * @version $Id$
 */
public class LabelDecomposingBayesianClustering<T extends AdditiveClusterable<T>>
        extends NearestNeighborClustering<T>
//   implements SampleInitializedOnlineClusteringMethod<T>
{
    // ------------------------------ FIELDS ------------------------------

    private static final Logger logger = Logger.getLogger(LabelDecomposingBayesianClustering.class);

    GenericFactory<T> prototypeFactory;

    // --------------------------- CONSTRUCTORS ---------------------------

    /**
     * @param dm                       The distance measure to use
     * @param unknownDistanceThreshold the minimum probability to accept when adding a point to a cluster
     */
    public LabelDecomposingBayesianClustering(final DissimilarityMeasure<T> dm,
            final double unknownDistanceThreshold, final Set<String> potentialTrainingBins,
            final Map<String, Set<String>> predictLabelSets, final ProhibitionModel<T> prohibitionModel,
            final Set<String> testLabels) {
        super(dm, unknownDistanceThreshold, potentialTrainingBins, predictLabelSets, prohibitionModel, testLabels);
    }

    // ------------------------ INTERFACE METHODS ------------------------

    // --------------------- Interface PrototypeBasedCentroidClusteringMethod ---------------------

    public void setPrototypeFactory(final GenericFactory<T> prototypeFactory) throws GenericFactoryException {
        this.prototypeFactory = prototypeFactory;
    }

    // --------------------- Interface SampleInitializedOnlineClusteringMethod ---------------------

    protected void trainWithKnownTrainingLabels(final ClusterableIterator<T> trainingIterator) {
        throw new NotImplementedException();
    }

    /**
     * {@inheritDoc}
     */
    public void initializeWithSamples(final ClusterableIterator<T> trainingIterator, final int initSamples) //GenericFactory<T> prototypeFactory)
    //   throws ClusterException
    {
        final Map<String, GrowableKmeansClustering<T>> theSubclusteringMap = new HashMap<String, GrowableKmeansClustering<T>>();

        if (predictLabelSets.size() > 1) {
            throw new ClusterRuntimeException(
                    "LabelDecomposingBayesianClustering can't yet handle more than one exclusive label set at a time: "
                            + predictLabelSets.keySet());
        }

        final Set<String> predictLabels = predictLabelSets.values().iterator().next();
        try {
            // BAD consume the entire iterator, ignoring initsamples
            final Multinomial<Cluster<T>> priorsMult = new Multinomial<Cluster<T>>();
            try {
                int i = 0;
                while (true) {

                    final T point = trainingIterator.nextFullyLabelled();

                    final String bestLabel = point.getImmutableWeightedLabels().getDominantKeyInSet(predictLabels);
                    //Cluster<T> cluster = theClusterMap.get(bestLabel);

                    GrowableKmeansClustering<T> theIntraLabelClustering = theSubclusteringMap.get(bestLabel);

                    if (theIntraLabelClustering == null) {
                        theIntraLabelClustering = new GrowableKmeansClustering<T>(measure, potentialTrainingBins,
                                predictLabelSets, prohibitionModel, testLabels);
                        theSubclusteringMap.put(bestLabel, theIntraLabelClustering);
                    }

                    // naive online agglomerative clustering:
                    // add points to clusters in the order they arrive, one pass only, create new clusters as needed

                    // the resulting clustering may suck, but it should still more or less span the space of the inputs,
                    // so it may work well enough for this purpose.

                    // doing proper k-means would be nicer, but then we'd have to store all the training points, or re-iterate them somehow.

                    final ClusterMove<T, CentroidCluster<T>> cm = theIntraLabelClustering.bestClusterMove(point);

                    CentroidCluster<T> cluster = cm.bestCluster;

                    if (cm.bestDistance > unknownDistanceThreshold) {
                        logger.debug("Creating new subcluster (" + cm.bestDistance + " > "
                                + unknownDistanceThreshold + ") for " + bestLabel);
                        cluster = new AdditiveCentroidCluster<T>(i++, prototypeFactory.create());
                        //cluster.setId(i++);

                        // add the new cluster to the local per-label clustering...
                        theIntraLabelClustering.addCluster(cluster);

                        // ... and also to the overall clustering
                        addCluster(cluster);

                        // REVIEW for now we make a uniform prior
                        priorsMult.put(cluster, 1);
                    }
                    cluster.add(point);
                    /*      if(cluster.getLabelCounts().uniqueSet().size() != 1)
                             {
                             throw new Error();
                             }*/

                }
            } catch (NoSuchElementException e) {
                // iterator exhausted
            }
            priorsMult.normalize();

            //         clusterPriors = priorsMult.getValueMap();

            final ImmutableMap.Builder<Cluster<T>, Double> builder = ImmutableMap.builder();
            clusterPriors = builder.putAll(priorsMult.getValueMap()).build();

            //theClusters = theSubclusteringMap.values();

            for (final Map.Entry<String, GrowableKmeansClustering<T>> entry : theSubclusteringMap.entrySet()) {
                final String label = entry.getKey();
                final GrowableKmeansClustering<T> theIntraLabelClustering = entry.getValue();
                if (logger.isInfoEnabled()) {
                    logger.info("Created " + theIntraLabelClustering.getClusters().size() + " clusters from "
                            + theIntraLabelClustering.getN() + " points for " + label);
                }
            }
        } catch (DistributionException e) {
            throw new ClusterRuntimeException(e);
        } catch (GenericFactoryException e) {
            throw new ClusterRuntimeException(e);
        }
    }

    // -------------------------- OTHER METHODS --------------------------

    /**
     * {@inheritDoc}
     */
    /*   public void train(CollectionIteratorFactory<T> trainingCollectionIteratorFactory)
     throws IOException, ClusterException
          {
          //super.train(trainingCollectionIteratorFactory);
        
          //limitToPopulatedClusters();
        
          // after that, normalize the label probabilities
        
          removeEmptyClusters();
          normalizeClusterLabelProbabilities();
          }*/
}