edu.berkeley.compbio.ml.cluster.kohonen.KohonenSOM2D.java Source code

Java tutorial

Introduction

Here is the source code for edu.berkeley.compbio.ml.cluster.kohonen.KohonenSOM2D.java

Source

/*
 * Copyright (c) 2006-2013  David Soergel  <dev@davidsoergel.com>
 * Licensed under the Apache License, Version 2.0
 * http://www.apache.org/licenses/LICENSE-2.0
 */

package edu.berkeley.compbio.ml.cluster.kohonen;

import com.davidsoergel.dsutils.DSArrayUtils;
import com.davidsoergel.dsutils.GenericFactory;
import com.davidsoergel.dsutils.GenericFactoryException;
import com.davidsoergel.stats.DissimilarityMeasure;
import com.davidsoergel.stats.SimpleFunction;
import edu.berkeley.compbio.ml.cluster.AbstractUnsupervisedOnlineClusteringMethod;
import edu.berkeley.compbio.ml.cluster.AdditiveClusterable;
import edu.berkeley.compbio.ml.cluster.CentroidCluster;
import edu.berkeley.compbio.ml.cluster.CentroidClusteringUtils;
import edu.berkeley.compbio.ml.cluster.ClusterException;
import edu.berkeley.compbio.ml.cluster.ClusterMove;
import edu.berkeley.compbio.ml.cluster.ClusterRuntimeException;
import edu.berkeley.compbio.ml.cluster.ClusterableIterator;
import edu.berkeley.compbio.ml.cluster.ClusterableIteratorFactory;
import edu.berkeley.compbio.ml.cluster.NoGoodClusterException;
import edu.berkeley.compbio.ml.cluster.ProhibitionModel;
import org.apache.commons.lang.NotImplementedException;
import org.apache.log4j.Logger;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.io.ByteArrayOutputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Kohonen Self Organizing Map implementation for a rectangular grid of two dimensions.
 * <p/>
 * The standard algorithm moves the cells towards the winner cell; in our implementation we also have the option to remove the point from the cell where it was previously assigned, moving the
 * neighbors away from the moved node a bit.
 * <p/>
 * Note that because this is an "online" method, we can't do PCA or whatever to initialize the grid.  That's OK; we'll just initialize the grid with a uniform prototype; after placing the first
 * incoming point with a neighborhood encompassing the whole grid, all cells will be differentiated.
 *
 * @author <a href="mailto:dev@davidsoergel.com">David Soergel</a>
 * @version $Id$
 * @Author David Soergel
 * @Version 1.0
 */
public class KohonenSOM2D<T extends AdditiveClusterable<T>>
        extends AbstractUnsupervisedOnlineClusteringMethod<T, KohonenSOMCell<T>> implements KohonenSOM<T> {
    // ------------------------------ FIELDS ------------------------------

    // The cells are not kept in a 2-D array of their own: they live in the superclass's flat cluster list, and grid
    // coordinates are mapped onto list indices by listIndexFor(x, y) and back by cellPositionFor(index).
    private static final Logger logger = Logger.getLogger(KohonenSOM2D.class);

    // how many cells wide the grid is along each axis (always two entries; enforced by the constructor)
    final int[] cellsPerDimension;
    // upper bound on the neighborhood radius; the constructor sets it to DSArrayUtils.norm(cellsPerDimension) / 2
    double maxRadius;
    // lower bound on the neighborhood radius (see getCurrentRadius())
    final double minRadius;

    // training clock: incremented once per add() in this class; the input to moveFactorFunction and radiusFunction
    int time = 0;

    // how many point assignments have changed in this epoch (incremented by add(), cleared by resetChanged())
    int changed = 0;

    // number of grid axes; the constructor rejects anything other than 2
    private final int dimensions;
    // when true, neighborhoods continue across the grid edges onto the opposite side
    private final boolean edgesWrap;

    // when true, add() also pushes the neighborhood of a point's previous cell away from the point
    private final boolean decrementLosingNeighborhood;

    // how strong the motion should be vs. time (add() clamps the result to [0, 1])
    private final SimpleFunction moveFactorFunction;

    // what radius should be considered vs. time (clamped by getCurrentRadius())
    private final SimpleFunction radiusFunction;

    // how strong the motion should be vs. fraction of the radius; may be null, meaning weight 1 everywhere
    private final SimpleFunction weightFunction;
    // cache of filled neighborhood masks, keyed by integer radius (see getWeightedMask())
    private final Map<Integer, WeightedMask> weightedMasks = new HashMap<Integer, WeightedMask>();
    // cache of ring-shaped masks, keyed by integer radius (see getShellMask())
    private final Map<Integer, WeightedMask> shellMasks = new HashMap<Integer, WeightedMask>();

    // finds the best-matching cell for a point (see bestClusterMove())
    private final KohonenSOM2DSearchStrategy<T> searchStrategy;

    // spreads labels across the grid at the end of train(); set via setLabeler()
    private LabelDiffuser<T, KohonenSOMCell<T>> labeler;

    // --------------------------- CONSTRUCTORS ---------------------------

    /**
     * Creates a self-organizing map over a two-dimensional grid of cells.  The cells themselves are created later, by
     * {@link #setPrototypeFactory}.
     *
     * @param dm                          dissimilarity measure, passed to the superclass
     * @param potentialTrainingBins       passed to the superclass
     * @param predictLabelSets            passed to the superclass
     * @param prohibitionModel            passed to the superclass
     * @param testLabels                  passed to the superclass
     * @param cellsPerDimension           number of cells along each grid axis; must have exactly two entries
     * @param moveFactorFunction          how strongly cells move towards a point vs. time (clamped to [0, 1] in add)
     * @param radiusFunction              neighborhood radius vs. time (clamped to [minRadius, maxRadius])
     * @param weightFunction              move weight vs. fraction of the radius; null means weight 1 everywhere
     * @param decrementLosingNeighborhood whether a point's previous cell neighborhood is pushed away from it in add
     * @param edgesWrap                   whether the grid wraps around at its edges
     * @param minRadius                   lower bound for the neighborhood radius
     * @param searchStrategy              finds the best-matching cell for each point; receives this map's measure
     * @throws ClusterRuntimeException if {@code cellsPerDimension} does not have exactly two entries
     */
    public KohonenSOM2D(final DissimilarityMeasure<T> dm, final Set<String> potentialTrainingBins,
            final Map<String, Set<String>> predictLabelSets, final ProhibitionModel<T> prohibitionModel,
            final Set<String> testLabels, @NotNull final Integer[] cellsPerDimension,
            final SimpleFunction moveFactorFunction, final SimpleFunction radiusFunction,
            final SimpleFunction weightFunction, final boolean decrementLosingNeighborhood, final boolean edgesWrap,
            final double minRadius, final KohonenSOM2DSearchStrategy<T> searchStrategy) {
        super(dm, potentialTrainingBins, predictLabelSets, prohibitionModel, testLabels);

        this.cellsPerDimension = DSArrayUtils.toPrimitive(cellsPerDimension);

        this.dimensions = cellsPerDimension.length;
        this.moveFactorFunction = moveFactorFunction;
        this.radiusFunction = radiusFunction;
        this.weightFunction = weightFunction;
        this.decrementLosingNeighborhood = decrementLosingNeighborhood;
        this.edgesWrap = edgesWrap;
        this.minRadius = minRadius;
        this.searchStrategy = searchStrategy;

        if (dimensions != 2) {
            throw new ClusterRuntimeException("KohonenSOM2D accepts only two-dimensional grid.");
        }

        // one cell per grid position; the cells are stored row by row in the superclass's cluster list
        setNumClusters(this.cellsPerDimension[0] * this.cellsPerDimension[1]);

        // largest neighborhood radius: half the norm of the grid size (presumably the Euclidean norm, i.e. half the
        // grid diagonal -- confirm against DSArrayUtils.norm)
        maxRadius = DSArrayUtils.norm(this.cellsPerDimension) / 2.;

        searchStrategy.setDistanceMeasure(measure);
    }

    // --------------------- GETTER / SETTER METHODS ---------------------

    /**
     * @return how many point assignments have changed since construction or the last {@link #resetChanged()}
     */
    public int getChanged() {
        return this.changed;
    }

    /**
     * Sets the label diffuser that {@code train} uses to spread labels over the grid once training is done.
     *
     * @param labeler the label diffuser to use
     */
    public void setLabeler(final LabelDiffuser<T, KohonenSOMCell<T>> labeler) {
        this.labeler = labeler;
    }

    /**
     * Intentionally does nothing: every cell of a SOM grid must survive even when no points are assigned to it, so the
     * usual removal of empty clusters is disabled here.
     */
    protected void removeEmptyClusters() {
        // no-op: empty cells are a normal, required part of the map
    }
    // ------------------------ INTERFACE METHODS ------------------------

    // --------------------- Interface CentroidClusteringMethod ---------------------

    /**
     * Returns the short clustering summary computed by {@link CentroidClusteringUtils} over this map's cells, using the
     * map's dissimilarity measure.
     */
    @Override
    public String shortClusteringStats() {
        final String summary = CentroidClusteringUtils.shortClusteringStats(getClusters(), measure);
        return summary;
    }

    /**
     * Delegates to {@link CentroidClusteringUtils#computeClusterStdDevs} with this map's cells, measure and current
     * point assignments, reading the points from {@code theDataPointProvider}.
     *
     * @param theDataPointProvider source of the data points
     */
    public void computeClusterStdDevs(final ClusterableIterator<T> theDataPointProvider) {
        CentroidClusteringUtils.computeClusterStdDevs(getClusters(), measure, getAssignments(), theDataPointProvider);
    }

    /**
     * Renders the full clustering statistics into a String.  The bytes written by {@link CentroidClusteringUtils} are
     * decoded with the platform default charset ({@code ByteArrayOutputStream.toString()}).
     */
    @Override
    public String clusteringStats() {
        final ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        CentroidClusteringUtils.writeClusteringStatsToStream(getClusters(), measure, buffer);
        final String rendered = buffer.toString();
        return rendered;
    }

    /**
     * Writes the full clustering statistics for this map's cells to {@code outf}.
     *
     * @param outf destination stream (not closed here)
     */
    public void writeClusteringStatsToStream(final OutputStream outf) {
        final OutputStream target = outf;
        CentroidClusteringUtils.writeClusteringStatsToStream(getClusters(), measure, target);
    }

    // --------------------- Interface DiffusableLabelClusteringMethod ---------------------

    /**
     * {@inheritDoc}
     * <p/>
     * Delegates to this class's {@code NeighborhoodShellIterator} (defined elsewhere in the file), started at
     * {@code cell}.
     */
    public Iterator<Set<KohonenSOMCell<T>>> getNeighborhoodShellIterator(final KohonenSOMCell<T> cell) {
        final Iterator<Set<KohonenSOMCell<T>>> shells = new NeighborhoodShellIterator(cell);
        return shells;
    }

    // --------------------- Interface OnlineClusteringMethod ---------------------

    /**
     * Presents one data point to the map (one online training step).
     * <p/>
     * The best-matching cell is found via {@link #bestClusterMove}.  If the assignment changed it is recorded, and in
     * every case the winner's neighborhood (current radius, see {@link #getCurrentRadius()}) is pulled towards the
     * point with strength {@code moveFactorFunction(time)} (clamped to [0, 1]) times each mask weight.  If
     * {@code decrementLosingNeighborhood} is set and the point had a previous cell, the point's labels are first taken
     * back out of that previous cell and its neighborhood is pushed away from the point.  Advances the training clock.
     *
     * @param p the point to add
     * @return always {@code true}
     * @throws NoGoodClusterException propagated from {@link #bestClusterMove} when no acceptable cell is found
     */
    public boolean add(final T p) throws NoGoodClusterException {
        // ** this is not synchronized!  I think it's OK, but be careful...
        // that should really only cause trouble if the same point gets added twice and simultaneously, and gets
        // assigned to different clusters.  That seems highly unlikely.

        final ClusterMove<T, KohonenSOMCell<T>> cm = bestClusterMove(p);

        if (cm.isChanged()) {
            changed++;
            putAssignment(p.getId(), cm.bestCluster);
        }

        // do the moves whether or not the assignment changed
        final KohonenSOMCell<T> loser = cm.oldCluster;
        final KohonenSOMCell<T> winner = cm.bestCluster;

        final double moveFactor = Math.max(0, Math.min(1, moveFactorFunction.f(time)));
        final double radius = getCurrentRadius();

        logger.trace("Adding point with neighborhood radius " + radius + ", moveFactor " + moveFactor);

        // both neighborhoods use the same (cached) mask
        final WeightedMask mask = getWeightedMask((int) radius);

        // REVIEW decrementLosingNeighborhood has issues:
        // 1. it leaves the average count number very low
        // 2. in many of our runs we have an infinite supply of new samples, and never reclassify old samples
        if (decrementLosingNeighborhood && loser != null) {
            // The point's labels were accumulated in the cell it used to belong to, so they come out of that cell;
            // the winner receives them below.  (Removing them from the winner instead made the removal a no-op
            // for the loser and could drive the winner's counts negative.)
            loser.getMutableWeightedLabels().removeAll(p.getMutableWeightedLabels());
            for (Iterator<WeightedCell> i = mask.iterator(loser); i.hasNext();) {
                final WeightedCell v = i.next();
                v.theCell.recenterByRemovingWeighted(p, moveFactor * v.weight);
            }
        }

        p.doneLabelling();
        winner.getMutableWeightedLabels().addAll(p.getImmutableWeightedLabels());
        for (Iterator<WeightedCell> i = mask.iterator(winner); i.hasNext();) {
            final WeightedCell v = i.next();
            // intended update: neighbor = (1 - motionFactor) * neighbor + motionFactor * p
            // REVIEW confirm that recenterByAddingWeighted implements exactly this
            v.theCell.recenterByAddingWeighted(p, moveFactor * v.weight);
        }

        time++;
        return true;
    }

    /**
     * Trains the map with the superclass's online training loop, then spreads labels across the grid with the
     * configured {@link LabelDiffuser} and finalizes cluster labelling.
     *
     * @param trainingCollectionIteratorFactory source of the training points
     * @param iterations                        number of training iterations, passed to the superclass
     * @throws ClusterException        propagated from the superclass training loop
     * @throws ClusterRuntimeException if no label diffuser was set via {@link #setLabeler}
     */
    @Override
    public void train(final ClusterableIteratorFactory<T> trainingCollectionIteratorFactory, final int iterations)
            throws ClusterException {
        // fail before the (potentially long) training run instead of with a NullPointerException after it
        if (labeler == null) {
            throw new ClusterRuntimeException(
                    "KohonenSOM2D.train() requires a LabelDiffuser; call setLabeler() before training.");
        }
        super.train(trainingCollectionIteratorFactory, iterations);
        labeler.propagateLabels(this);
        doneLabellingClusters();
    }

    // --------------------- Interface PrototypeBasedCentroidClusteringMethod ---------------------

    /*
     private void initializeClusters(List<Interval<Double>> axisRanges)
        {
        
        }
    */

    /**
     * Builds the grid: creates one cell per grid position, each centroid produced by {@code prototypeFactory}, and then
     * hands this map to the search strategy.
     *
     * @param prototypeFactory factory for the initial cell centroids
     * @throws GenericFactoryException if the factory fails to create a centroid
     */
    public void setPrototypeFactory(final GenericFactory<T> prototypeFactory) throws GenericFactoryException {
        final int width = cellsPerDimension[0];
        final int height = cellsPerDimension[1];
        createClusters(width * height, prototypeFactory);

        searchStrategy.setSOM(this);
    }

    // --------------------- Interface SampleInitializedOnlineClusteringMethod ---------------------

    /**
     * {@inheritDoc}
     * <p/>
     * Each of the next {@code initSamples} points obtained from {@code initIterator.nextFullyLabelled()} is blended into
     * a randomly chosen cell and its neighborhood by {@link #addToRandomCell}; labels and the training clock are left
     * untouched.  Progress is logged at debug level every 100 samples.
     *
     * @param initIterator source of initialization points
     * @param initSamples  number of points to draw
     */
    public void initializeWithSamples(final ClusterableIterator<T> initIterator, final int initSamples) {
        for (int i = 0; i < initSamples; i++) {
            addToRandomCell(initIterator.nextFullyLabelled());

            // report the number of samples actually added so far (i is zero-based)
            final int added = i + 1;
            if (added % 100 == 0) {
                logger.debug("Initialized with " + added + " samples.");
            }
        }
    }

    // -------------------------- OTHER METHODS --------------------------

    /**
     * Initialization step: moves a randomly chosen cell and its neighborhood (radius {@code maxRadius}) towards
     * {@code p} with move factor 0.5, scaled by each mask weight.  No labels are added and the training clock is not
     * advanced; this is only for random initialization.
     *
     * @param p the point to blend into the map
     */
    public void addToRandomCell(final T p) {
        final KohonenSOMCell<T> center = (KohonenSOMCell<T>) chooseRandomCluster();

        final double moveFactor = .5;
        final double radius = maxRadius;
        logger.trace("Adding point with neighborhood radius " + radius + ", moveFactor " + moveFactor);

        final Iterator<WeightedCell> neighborhood = getWeightedMask((int) radius).iterator(center);
        while (neighborhood.hasNext()) {
            final WeightedCell entry = neighborhood.next();
            entry.theCell.recenterByAddingWeighted(p, moveFactor * entry.weight);
        }
    }

    /**
     * {@inheritDoc}
     */
    /*   @Override
     public Cluster<T> getBestCluster(T p, List<Double> secondBestDistances)
      throws ClusterException, NoGoodClusterException
        {
        ClusterMove cm = bestClusterMove(p);
        return cm.bestCluster;
        }
    */

    /**
     * {@inheritDoc}
     * <p/>
     * The search itself is delegated to the configured {@link KohonenSOM2DSearchStrategy}.
     */
    @Override
    public ClusterMove<T, KohonenSOMCell<T>> bestClusterMove(final T p) throws NoGoodClusterException {
        final ClusterMove<T, KohonenSOMCell<T>> move = searchStrategy.bestClusterMove(p);
        return move;
    }

    /**
     * Computes, for every cell, the average distance from its centroid to the centroids of its straight-line
     * (non-diagonal) grid neighbors -- the values of a U-matrix.
     * <p/>
     * Each neighboring pair is measured once and credited to both cells (distances are assumed symmetric).  When edges
     * wrap, border cells are paired with the cells on the opposite border, so every cell has four neighbors.
     * Otherwise border cells simply have fewer neighbors (3 on an edge, 2 in a corner) and are averaged over that
     * count.
     *
     * @return per-cell average neighbor distance, indexed like the cluster list; a cell without neighbors gets 0
     */
    public double[] computeCellAverageNeighborDistances() {
        final double[] result = new double[getNumClusters()];
        final int[] neighborCounts = new int[result.length];

        final int width = cellsPerDimension[0];
        final int height = cellsPerDimension[1];

        // visit every cell and pair it with its right-hand and lower neighbor; together this covers every
        // horizontal and vertical adjacency exactly once (including the last column and the last row)
        for (int x = 0; x < width; x++) {
            for (int y = 0; y < height; y++) {
                final int hereIndex = listIndexFor(x, y);
                final CentroidCluster<T> here = clusterAt(x, y);

                if (edgesWrap || x + 1 < width) {
                    accumulateNeighborDistance(here, hereIndex, x + 1, y, result, neighborCounts);
                }
                if (edgesWrap || y + 1 < height) {
                    accumulateNeighborDistance(here, hereIndex, x, y + 1, result, neighborCounts);
                }
            }
        }

        for (int i = 0; i < result.length; i++) {
            if (neighborCounts[i] > 0) {
                result[i] /= neighborCounts[i];
            }
        }

        return result;
    }

    /**
     * Measures the distance between the cell at {@code hereIndex} and its grid neighbor at (neighborX, neighborY) and
     * adds it to both cells' running sums and neighbor counts.  A cell that is its own neighbor (a wrapping axis only
     * one cell wide) is skipped.
     */
    private void accumulateNeighborDistance(final CentroidCluster<T> here, final int hereIndex, final int neighborX,
            final int neighborY, final double[] sums, final int[] counts) {
        final int neighborIndex = listIndexFor(neighborX, neighborY);
        if (neighborIndex == hereIndex) {
            return;
        }
        final CentroidCluster<T> neighbor = getCluster(neighborIndex);
        final double d = measure.distanceFromTo(here.getCentroid(), neighbor.getCentroid());

        sums[hereIndex] += d;
        counts[hereIndex]++;
        sums[neighborIndex] += d;
        counts[neighborIndex]++;
    }

    /**
     * Returns the cell at grid position (x, y).
     *
     * @param x column, in cells
     * @param y row, in cells
     * @return the cell stored at the corresponding cluster-list index
     */
    public KohonenSOMCell<T> clusterAt(final int x, final int y) {
        final int index = listIndexFor(x, y);
        return getCluster(index);
    }

    /**
     * Maps grid coordinates to the index of the corresponding cell in the cluster list: {@code y * width + x} (the
     * inverse of {@code cellPositionFor}).
     * <p/>
     * When edges wrap, coordinates are first reduced into range, negative ones included, so any integer pair is
     * accepted.  When edges do not wrap, callers must pass {@code 0 <= x < width} and {@code 0 <= y < height}; this
     * is not checked.
     *
     * @param x column, in cells
     * @param y row, in cells
     * @return the cluster-list index of the cell
     */
    private int listIndexFor(int x, int y) {
        if (edgesWrap) {
            final int width = cellsPerDimension[0];
            final int height = cellsPerDimension[1];

            // Java's % keeps the sign of the dividend, so fold negative remainders back into [0, size)
            x %= width;
            if (x < 0) {
                x += width;
            }
            y %= height;
            if (y < 0) {
                y += height;
            }
        }

        return y * cellsPerDimension[0] + x;
    }

    /**
     * Inverse of {@code listIndexFor} for in-range indices: converts a cluster-list index back to grid coordinates.
     *
     * @param listIndex index into the cluster list
     * @return a two-element array {x, y}
     */
    private int[] cellPositionFor(final int listIndex) {
        final int width = cellsPerDimension[0];
        return new int[] { listIndex % width, listIndex / width };
    }

    /**
     * Creates {@code totalCells} cells, numbered 0..totalCells-1, and appends them to the cluster list.
     * <p/>
     * Each centroid comes from {@code prototypeFactory.create(String.valueOf(i))} and is marked done-labelling.
     * When no factory is given, the cells get a null centroid.
     *
     * @param totalCells       number of cells to create
     * @param prototypeFactory source of initial centroids; may be null
     * @throws GenericFactoryException if the factory fails
     */
    private void createClusters(final int totalCells, final GenericFactory<T> prototypeFactory)
            throws GenericFactoryException {
        for (int i = 0; i < totalCells; i++) {
            final T centroid = prototypeFactory == null ? null : prototypeFactory.create(String.valueOf(i));
            // the null-factory case is explicitly allowed above, so don't dereference a missing centroid
            if (centroid != null) {
                centroid.doneLabelling();
            }
            addCluster(new KohonenSOMCell<T>(i, centroid));
        }
    }

    /**
     * Returns the neighborhood radius for the current training time: {@code radiusFunction(time)} clamped into
     * [minRadius, maxRadius].  If the bounds conflict, minRadius wins.
     *
     * @return the current neighborhood radius, in cells
     */
    public double getCurrentRadius() {
        final double scheduled = radiusFunction.f(time);
        return Math.max(Math.min(scheduled, maxRadius), minRadius);
    }

    /**
     * Returns the ring of offsets that lie in the weighted mask of {@code radius} but not in that of
     * {@code radius - 1}, each with weight 1.  For {@code radius < 1} the radius-0 weighted mask (just the center) is
     * returned instead.  Results are cached per radius.
     *
     * @param radius ring radius, in cells
     * @return the (cached) shell mask
     */
    WeightedMask getShellMask(final int radius) {
        WeightedMask result = shellMasks.get(radius);
        if (result == null) {
            if (radius < 1) {
                result = getWeightedMask(0);
            } else {
                final WeightedMask outerMask = getWeightedMask(radius);
                final WeightedMask innerMask = getWeightedMask(radius - 1);
                final List<Integer> xList = new ArrayList<Integer>();
                final List<Integer> yList = new ArrayList<Integer>();

                // only the first numCells entries of a weighted mask are valid; the rest of the arrays is padding
                for (int i = 0; i < outerMask.numCells; i++) {
                    final int x = outerMask.deltaX[i];
                    final int y = outerMask.deltaY[i];
                    if (!innerMask.containsPoint(x, y)) {
                        xList.add(x);
                        yList.add(y);
                    }
                }

                result = new WeightedMask();
                result.deltaX = DSArrayUtils.toPrimitive(xList.toArray(new Integer[xList.size()]));
                result.deltaY = DSArrayUtils.toPrimitive(yList.toArray(new Integer[yList.size()]));
                result.weight = new double[result.deltaX.length];
                Arrays.fill(result.weight, 1);
                result.numCells = result.deltaX.length;
            }
            shellMasks.put(radius, result);
        }
        return result;
    }

    /**
     * Returns the filled neighborhood mask of the given radius, building and caching it on first use.
     *
     * @param radius mask radius, in cells
     * @return the cached mask for that radius
     */
    WeightedMask getWeightedMask(final int radius) {
        final WeightedMask cached = weightedMasks.get(radius);
        if (cached != null) {
            return cached;
        }
        final WeightedMask built = new WeightedMask(radius);
        weightedMasks.put(radius, built);
        return built;
    }

    /*
     public Set<Cluster<T>> watershedClustering(double threshold)
        {
        double[] uMatrix = computeCellAverageNeighborDistances();
        
        int width = cellsPerDimension[0];
        int height = cellsPerDimension[1];
        for (int x = 0; x < width; x++)
      {
      for (int y = 0; y < height; y++)
         {
         if (uMatrix[listIndexFor(x, y)] < threshold)
            {
        
            }
         }
      }
        }
    */

    /**
     * Clears the count of changed assignments reported by {@link #getChanged()}.
     */
    public void resetChanged() {
        this.changed = 0;
    }

    /**
     * Trains for {@code trainingEpochs} iterations.
     * <p/>
     * Note: {@code prototypeFactory} is not used by this implementation; the grid's cells must already have been
     * created (see {@link #setPrototypeFactory}).
     *
     * @param trainingCollectionIteratorFactory source of the training points
     * @param prototypeFactory                  ignored
     * @param trainingEpochs                    number of training iterations
     * @throws ClusterException propagated from training
     */
    public void train(final ClusterableIteratorFactory<T> trainingCollectionIteratorFactory,
            final GenericFactory<T> prototypeFactory, final int trainingEpochs) throws ClusterException {
        this.train(trainingCollectionIteratorFactory, trainingEpochs);
    }

    // -------------------------- INNER CLASSES --------------------------

    /**
     * Iterates over all the cells within a given radius of a center cell, using a fast algorithm from http://homepage.smc.edu/kennedy_john/BCIRCLE.PDF
     * <p/>
     * We can probably speed this up further by caching the results (circle masks, basically).
     * <p/>
     * Note we can't just return the cell, because if the edges wrap, then the same cell may be returned up to four times, but with different distances; so we need to return the distance too.
     */
    class WeightedMask {
        // ------------------------------ FIELDS ------------------------------

        // offsets (x, y) from the center cell, one pair per mask entry; only the first numCells entries are valid
        int[] deltaX;
        int[] deltaY;

        // weight of each entry, parallel to deltaX/deltaY; derived from weightFunction (1 when it is null).
        // Duplicated per entry rather than recomputed, for easy & fast access.
        double[] weight;

        // The number of cells in a disc isn't known in advance, so the arrays above may be allocated larger than
        // needed.  numCells is the logical length (the valid prefix), as opposed to the physical array length.
        int numCells;

        // --------------------------- CONSTRUCTORS ---------------------------

        private WeightedMask() {
            // empty mask; the caller assigns deltaX, deltaY, weight and numCells directly (see getShellMask)
        }

        /**
         * Builds a filled, roughly circular mask of offsets around the origin.  The fast Bresenham-type circle
         * algorithm referenced in the class comment traces the outline, and everything inside it is filled in.  Every
         * offset appears exactly once, and the set is symmetric under reflections and 90-degree rotations.
         * <p/>
         * Each offset is weighted by {@code weightFunction} evaluated at its Euclidean distance from the center divided
         * by {@code radius}, or 1 when no weight function is set.  A radius of zero or less produces a mask holding
         * only the center cell.
         *
         * @param radius the mask radius, in cells
         */
        private WeightedMask(final int radius) {
            if (radius <= 0) {
                deltaX = new int[1];
                deltaY = new int[1];
                weight = new double[1];

                deltaX[0] = 0;
                deltaY[0] = 0;
                weight[0] = weightFunction == null ? 1 : weightFunction.f(0);
                numCells = 1;
            } else {
                // No offset is emitted twice and all of them lie inside the (2r+1) x (2r+1) bounding square, so
                // arrays of that size can never overflow; numCells records how much of them is actually used.
                final int side = 2 * radius + 1;
                deltaX = new int[side * side];
                deltaY = new int[side * side];
                weight = new double[side * side];

                // always add the center (only once)
                int i = putOffset(0, 0, 0, maskWeight(0, radius));

                int x = radius;
                int xChange = 1 - 2 * radius;
                int y = 0;
                int yChange = 1;
                int radiusError = 0;
                while (x >= y) {
                    i = plot8CirclePoints(i, x, y, radius);

                    y++;
                    radiusError += yChange;
                    yChange += 2;
                    if (2 * radiusError + xChange > 0) {
                        x--;
                        radiusError += xChange;
                        xChange += 2;
                    }
                }

                numCells = i;
            }
        }

        /**
         * Adds the offsets contributed by one step (x, y) of the circle algorithm, filling the disc rather than only
         * tracing its outline.  The constructor calls this exactly once for each y = 0, 1, 2, ... while x >= y.  At
         * each call, x is the outline's extent at that y.
         * <p/>
         * For y == 0 (the first call, where x == radius) it adds both axes: (+-k, 0) and (0, +-k) for k = 1..x.
         * For y > 0 it adds the off-axis, off-diagonal cells (+-k, +-y) and (+-y, +-k) for y < k <= x, plus the four
         * diagonal cells (+-y, +-y).  The center is added by the constructor, so no offset is ever produced twice.
         *
         * @param i      next free index into the mask arrays
         * @param x      outline extent for this y (used as a loop counter)
         * @param y      current distance from the horizontal axis
         * @param radius the mask radius, used to normalize distances for the weight function
         * @return the next free index after the added offsets
         */
        private int plot8CirclePoints(int i, int x, final int y, final int radius) {
            if (x != 0 && y != 0) {
                // stop at x > y: the diagonal is added separately below so it isn't counted twice
                for (; x > y; x--) {
                    final double theWeight = maskWeight(Math.sqrt(x * x + y * y), radius);
                    i = putOffset(i, x, y, theWeight);
                    i = putOffset(i, x, -y, theWeight);
                    i = putOffset(i, -x, y, theWeight);
                    i = putOffset(i, -x, -y, theWeight);
                    i = putOffset(i, y, x, theWeight);
                    i = putOffset(i, y, -x, theWeight);
                    i = putOffset(i, -y, x, theWeight);
                    i = putOffset(i, -y, -x, theWeight);
                }

                // the four diagonal cells (x == y), each once; y != 0 so this never touches the center
                final double diagonalWeight = maskWeight(Math.sqrt(y * y + y * y), radius);
                i = putOffset(i, y, y, diagonalWeight);
                i = putOffset(i, y, -y, diagonalWeight);
                i = putOffset(i, -y, y, diagonalWeight);
                i = putOffset(i, -y, -y, diagonalWeight);
            } else if (y == 0 && x != 0) {
                // Both axes are filled here, in full.  Adding vertical-axis points one per later call (as for the
                // diagonal) only reaches the largest y the algorithm visits, which left holes on the vertical axis
                // (e.g. (0, +-4) for radius 5).
                for (int k = x; k > 0; k--) {
                    final double theWeight = maskWeight(k, radius);
                    i = putOffset(i, k, 0, theWeight);
                    i = putOffset(i, -k, 0, theWeight);
                    i = putOffset(i, 0, k, theWeight);
                    i = putOffset(i, 0, -k, theWeight);
                }
            }
            // x == 0 with y != 0 cannot occur (the caller only loops while x >= y), and the center is the
            // constructor's responsibility.
            return i;
        }

        /**
         * Stores one offset and its weight at index {@code i}.
         *
         * @return {@code i + 1}, the next free index
         */
        private int putOffset(final int i, final int dx, final int dy, final double w) {
            deltaX[i] = dx;
            deltaY[i] = dy;
            weight[i] = w;
            return i + 1;
        }

        /**
         * Weight for an offset at the given distance from the center: {@code weightFunction(distance / radius)}, or 1
         * when no weight function is set.  Weights are expected to be strictly positive.
         */
        private double maskWeight(final double distance, final int radius) {
            final double w = weightFunction == null ? 1 : weightFunction.f(distance / (double) radius);
            assert w > 0;
            return w;
        }

        // -------------------------- OTHER METHODS --------------------------

        /**
         * Tells whether the offset (x, y) is part of this mask, by a linear scan over the valid entries.
         *
         * @param x horizontal offset from the center
         * @param y vertical offset from the center
         * @return true if some entry of the mask has exactly this offset
         */
        public boolean containsPoint(final int x, final int y) {
            // Only the first numCells entries are meaningful.  Scanning the whole (over-allocated) arrays would also
            // look at zero padding, i.e. phantom (0, 0) entries.
            for (int i = 0; i < numCells; i++) {
                if (deltaX[i] == x && deltaY[i] == y) {
                    return true;
                }
            }
            return false;
        }

        /**
         * Returns an iterator over the grid cells covered by this mask when centered on {@code center}, each paired
         * with its mask weight.  Entries falling off a non-wrapping grid are skipped.
         *
         * @param center the cell at the mask's origin
         * @return an iterator over the covered cells
         */
        public Iterator<WeightedCell> iterator(final KohonenSOMCell<T> center) {
            final Iterator<WeightedCell> cells = new MaskIterator(center);
            return cells;
        }

        // -------------------------- INNER CLASSES --------------------------

        private class MaskIterator implements Iterator<WeightedCell> {
            // ------------------------------ FIELDS ------------------------------

            // nextCell: the look-ahead cell found by findNextCell(), or null once the mask is exhausted.
            // currentCell: presumably the cell most recently returned by next() (next() is not shown here; confirm).
            WeightedCell currentCell, nextCell;
            // grid coordinates of the cell at the mask's origin
            final int xCenter;
            final int yCenter;

            // index into the mask arrays of the look-ahead entry (nextCell); -1 before the first search
            int trav = -1;

            // --------------------------- CONSTRUCTORS ---------------------------

            public MaskIterator(final KohonenSOMCell<T> center) {
                // PERF: locating the center's index in the cluster list may be costly (TODO confirm)
                final int[] position = cellPositionFor(getClusterIndexOf(center));
                xCenter = position[0];
                yCenter = position[1];

                // look ahead to the first valid cell of the mask
                nextCell = findNextCell();
            }

            @Nullable
            private WeightedCell findNextCell() {
                trav++;

                boolean foundCell = false;

                int realX = -1, realY = -1;// the loop below can't complete without setting these; if there's a bug we'll get ArrayIndexOutOfBoundsException

                // iterate rather than recurse to avoid huge stacks
                while (!foundCell) {
                    if (trav >= numCells) {
                        return null;
                    } else {
                        realX = xCenter + deltaX[trav];
                        realY = yCenter + deltaY[trav];

                        if (!edgesWrap && (realX < 0 || realX >= cellsPerDimension[0] || realY < 0
                                || realY >= cellsPerDimension[1])) {
                            // foundCell still false, try again
                            trav++;
                        } else {
                            realX %= cellsPerDimension[0];
                            // avoid negatives too
                            if (realX < 0) {
                                realX += cellsPerDimension[0];
                            }

                            realY %= cellsPerDimension[1];
                            // avoid negatives too
                            if (realY < 0) {
                                realY += cellsPerDimension[1];
                            }

                            foundCell = true;
                        }
                    }
                }

                return new WeightedCell(getCluster(listIndexFor(realX, realY)), weight[trav]);
            }

            // ------------------------ INTERFACE METHODS ------------------------

            // --------------------- Interface Iterator ---------------------

            /**
             * {@inheritDoc}
             */
            public boolean hasNext() {
                return nextCell != null;
            }

            /**
             * {@inheritDoc}
             */
            public WeightedCell next() {
                currentCell = nextCell;

                nextCell = findNextCell();

                return currentCell;
            }

            /**
             * {@inheritDoc}
             */
            public void remove() {
                throw new NotImplementedException();
            }
        }
    }

    /**
     * Immutable pairing of a SOM cell with the weight that a mask assigns to it.
     */
    class WeightedCell {
        /** The map cell. */
        final KohonenSOMCell<T> theCell;

        /**
         * The mask weight associated with this cell's offset from the mask center.
         */
        final double weight;

        private WeightedCell(final KohonenSOMCell<T> cell, final double cellWeight) {
            weight = cellWeight;
            theCell = cell;
        }
    }

    /*
        private class OldBogusNeighborhoodIterator implements Iterator<KohonenSOMCell<T>>
           {
           double radius;
        
           KohonenSOMCell<T> center;
           // this will likely need optimizing later
           Set<KohonenSOMCell<T>> todo = new HashSet<KohonenSOMCell<T>>();
           Set<KohonenSOMCell<T>> done = new HashSet<KohonenSOMCell<T>>();
        
        
           private NeighborhoodIterator(KohonenSOMCell<T> center, int time)
              {
              this.center = center;
              radius = radiusFunction.f(time);
              todo.add(center);
              }
        
           public boolean hasNext()
              {
              return !todo.isEmpty();
              }
        
           public KohonenSOMCell<T> next()
              {
              KohonenSOMCell<T> trav = todo.iterator().next();
              done.add(trav);
        
              computeImmediateNeighbors(trav);
              for (KohonenSOMCell<T> neighbor : immediateNeighbors)
                 {
                 // careful not to repeat cells when the radius is large
                 // no problem, the done list deals with that
        
                 // optimizations possible here, i.e. test squares inscribed in circle first before doing sqrt
                 if (neighbor != null && euclideanDistance(neighbor, center) <= radius && !done.contains(neighbor))
                    {
                    todo.add(neighbor);
                    }
                 }
              return trav;
              }
        
           private double euclideanDistance(KohonenSOMCell<T> neighbor, KohonenSOMCell<T> center)
              {
              int[] a = cellPositionFor(theClusters.indexOf(neighbor));
              int[] b = cellPositionFor(theClusters.indexOf(center));
        
              int sum = 0;
              for (int i = 0; i < dimensions; i++)
                 {
                 int dist = a[i] - b[i];
                 if (edgesWrap)
                    {
                    dist = Math.min(dist, b[i] - a[i]);
                    }
                 sum += dist * dist;
                 }
              return Math.sqrt(sum);
              }
        
        
           public void remove()
              {
              throw new NotImplementedException();
              }
           }*/

    /**
     * Produces successive shells of cells around a fixed center.
     * The n-th call to next() returns the cells selected by {@code getShellMask(n - 1)} when anchored at the center.
     * <p/>
     * This iterator is unbounded: hasNext() always returns true, so callers must decide when to stop.
     */
    private class NeighborhoodShellIterator implements Iterator<Set<KohonenSOMCell<T>>> {
        // ------------------------------ FIELDS ------------------------------

        /** Radius of the shell that the next call to next() will produce. */
        int radius = 0;

        private final KohonenSOMCell<T> center;

        // --------------------------- CONSTRUCTORS ---------------------------

        public NeighborhoodShellIterator(final KohonenSOMCell<T> center) {
            this.center = center;
        }

        // ------------------------ INTERFACE METHODS ------------------------

        // --------------------- Interface Iterator ---------------------

        /**
         * Always true; the shell sequence never ends.
         */
        public boolean hasNext() {
            return true;
        }

        /**
         * Returns the cells of the shell at the current radius, then advances the radius by one.
         * <p/>
         * The cells are collected into a set, so a cell reached through more than one mask offset appears only once.
         * PERF: a new shell mask is built on every call. This is inefficient, but the method is not called often.
         */
        public Set<KohonenSOMCell<T>> next() {
            final WeightedMask shell = getShellMask(radius);
            final Iterator<WeightedCell> cells = shell.iterator(center);

            final Set<KohonenSOMCell<T>> shellCells = new HashSet<KohonenSOMCell<T>>();
            while (cells.hasNext()) {
                final WeightedCell wc = cells.next();
                shellCells.add(wc.theCell);
            }

            radius++;
            return shellCells;
        }

        /**
         * {@inheritDoc}
         */
        public void remove() {
            throw new NotImplementedException();
        }
    }
}