edu.utsa.sifter.som.SelfOrganizingMap.java Source code

Java tutorial

Introduction

Here is the source code for edu.utsa.sifter.som.SelfOrganizingMap.java

Source

/**
 *
 * Sifter - Search Indexes for Text Evidence Relevantly
 *
 * Copyright (C) 2013, University of Texas at San Antonio (UTSA)
 *
 * Sifter is a digital forensics and e-discovery tool for conducting
 * text based string searches.  It clusters and ranks search hits
 * to improve investigative efficiency. Hit-level ranking uses a 
 * patent-pending ranking algorithm invented by Dr. Nicole Beebe at UTSA.
 *  
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * @author Jon Stewart, Lightbox Technologies
**/

package edu.utsa.sifter.som;

import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.TimesFunction;
import org.apache.mahout.math.jet.random.AbstractDistribution;

import java.util.Iterator;
import java.util.Set;
import java.util.HashSet;
import java.util.Random;
import java.util.ArrayList;
import java.util.PriorityQueue;
import java.util.HashMap;

public class SelfOrganizingMap {

    public static class CellStats {
        private int Num;
        private int Region;
        private double SumSquaredDistance;

        private java.util.Vector<String> TopTerms;

        public CellStats() {
            Num = 0;
            SumSquaredDistance = 0;
        }

        public int getNum() {
            return Num;
        }

        public double stdDev() {
            return Math.sqrt(Num > 2 ? SumSquaredDistance / (Num - 1) : SumSquaredDistance);
        }

        public double sumSqrDistance() {
            return SumSquaredDistance;
        }

        public void accumulate(final double dist) {
            ++Num;
            SumSquaredDistance += dist * dist;
        }

        public void setRegion(final int r) {
            Region = r;
        }

        public int getRegion() {
            return Region;
        }

        public java.util.Vector<String> getTopTerms() {
            return TopTerms;
        }

        public void setTopTerms(final java.util.Vector<String> t) {
            TopTerms = t;
        }
    }

    private final DenseMatrix Cells;
    private final double CellFactors[];
    private final double S2[];
    private final int Height, Width, Dims, NumCells;
    private final CellStats Stats[];

    private int NumRegions;
    private int[] RegionColors;
    private ArrayList<Set<Integer>> RegionMap;

    private ArrayList<HashMap<Integer, Integer>> CellTermDifferences;

    private final int NUM_POSSIBLE_REGION_COLORS = 9;

    public SelfOrganizingMap(int height, int width, int dimensions) {
        Height = height;
        Width = width;
        NumCells = Height * Width;
        Dims = dimensions;
        Cells = new DenseMatrix(NumCells, Dims);
        CellFactors = new double[NumCells];
        S2 = new double[NumCells];
        Stats = new CellStats[NumCells];
        NumRegions = 0;
    }

    public void init(final AbstractDistribution prng) {
        Cells.assign(prng);
        for (int i = 0; i < NumCells; ++i) {
            CellFactors[i] = 1.0;
            Stats[i] = new CellStats();
        }
        recalculateS2();
    }

    public void rescale() { // call this after the SOM has converged
        final TimesFunction multiplier = new TimesFunction();
        for (int i = 0; i < NumCells; ++i) {
            Cells.viewRow(i).assign(multiplier, CellFactors[i]);
            CellFactors[i] = 1.0;
        }
    }

    public void assignCell(final int id, final double dist) {
        Stats[id].accumulate(dist);
    }

    public CellStats getStats(final int id) {
        return Stats[id];
    }

    public void recalculateS2() {
        double s2;
        Vector cellWeights;
        double f;
        for (int id = 0; id < NumCells; ++id) {
            cellWeights = getCell(id);
            s2 = 0;
            f = CellFactors[id];
            for (Vector.Element w : cellWeights) {
                s2 += (f * w.get()) * (f * w.get());
            }
            S2[id] = s2;
        }
    }

    public final int getX(final int id) {
        return id % Width;
    }

    public final int getY(final int id) {
        return id / Width;
    }

    public final int getID(final int x, final int y) {
        return y * Width + x;
    }

    public final Vector getCell(final int x, final int y) {
        return Cells.viewRow(getID(x, y));
    }

    public final Vector getCell(final int id) {
        return Cells.viewRow(id);
    }

    public final double getFactor(final int id) {
        return CellFactors[id];
    }

    public final double getS2(final int id) {
        return S2[id];
    }

    public final int width() {
        return Width;
    }

    public final int height() {
        return Height;
    }

    public final int numCells() {
        return NumCells;
    }

    public final int dimensions() {
        return Dims;
    }

    public int getNumRegions() {
        return NumRegions;
    }

    public void setNumRegions(final int n) {
        NumRegions = n;
    }

    public int getRegionColor(final int region) {
        return RegionColors[region];
    }

    public ArrayList<Set<Integer>> getRegionMap() {
        return RegionMap;
    }

    public void getNeighborhood(final int x, final int y, final int radius, final IntArrayWritable neighbors) {
        final int minX = Math.max(x - radius, 0);
        final int minY = Math.max(y - radius, 0);
        final int maxX = Math.min(x + radius, Width - 1);
        final int maxY = Math.min(y + radius, Height - 1);

        neighbors.clear();

        for (int curY = minY; curY <= maxY; ++curY) {
            for (int curX = minX; curX <= maxX; ++curX) {
                neighbors.add(getID(curX, curY));
            }
        }
    }

    public double computeDistance(final int id, final IntArrayWritable doc) {
        final Vector cellWeights = getCell(id);
        final int cardinality = doc.getLength();
        final double s1 = cardinality;
        final double s2 = getS2(id);

        double s3 = 0.0;
        double c = 0.0, // Kahan summation algorithm to account for error, c.f. http://en.wikipedia.org/wiki/Kahan_summation_algorithm
                w, t;

        final double f = getFactor(id);
        final int[] terms = doc.getInts();
        for (int i = 0; i < cardinality; ++i) {
            w = (f * cellWeights.getQuick(terms[i])) - c;
            t = s3 + w;
            c = (t - s3) - w;
            s3 = t;
            // s3 += f * cellWeights.getQuick(itr.next().index());
        }
        s3 *= -2;

        final double d = s1 + s2 + s3;
        if (d < 0) {
            System.out.println("Negative distance on " + id + " - d = " + d + ", s1 = " + s1 + ", s2 = " + s2
                    + ", s3 = " + s3 + ", f = " + f);
        }
        return d;
    }

    public void updateCell(final int id, final double alpha, final IntArrayWritable doc) {
        // Scalable SOM updating, per Roussinov

        final double rate = 1 - alpha;
        final double f = CellFactors[id];
        final double nextF = rate * f; // Rule 5

        final double adjustment = alpha / (rate * CellFactors[id]); // Rule 6
        double sumSqrOld = 0.0;
        double sumSqrNew = 0.0;
        double c1 = 0.0, // Kahan summation algorithm to account for error, c.f. http://en.wikipedia.org/wiki/Kahan_summation_algorithm
                c2 = 0.0, y, t;

        final Vector weights = getCell(id);

        double weight;
        double trueWeight;
        int idx;
        final int[] terms = doc.getInts();
        final int numTerms = doc.getLength();
        for (int i = 0; i < numTerms; ++i) {
            idx = terms[i];
            weight = weights.getQuick(idx);

            trueWeight = weight * f;
            y = (trueWeight * trueWeight) - c1;
            t = sumSqrOld + y; // S'(t+1) component
            c1 = (t - sumSqrOld) - y;
            sumSqrOld = t;
            // sumSqrOld += trueWeight * trueWeight;

            weight += adjustment; // adjust weight

            trueWeight = weight * nextF;

            y = (trueWeight * trueWeight) - c2;
            t = sumSqrNew + y;
            c2 = (t - sumSqrNew) - y;
            sumSqrNew = t;
            // sumSqrNew += trueWeight * trueWeight; // S_2'(t+1) component

            weights.setQuick(idx, weight);
        }
        CellFactors[id] = nextF;
        S2[id] = sumSqrNew + (rate * rate) * (S2[id] - sumSqrOld); // new S2 component
    }

    void assignTopTerms(final int numTopTerms, final java.util.Vector<String> terms) {
        final PriorityQueue<TermPair> topWeights = new PriorityQueue<TermPair>(numTopTerms,
                new TermPair.TermPairComparator());

        for (int i = 0; i < numCells(); ++i) {
            final java.util.Vector<String> topTerms = new java.util.Vector<String>(numTopTerms);
            topTerms.setSize(numTopTerms);

            final Vector cell = getCell(i);
            final double f = getFactor(i);

            topWeights.clear();
            for (Vector.Element w : cell) {
                int val = (int) (1000 * f * w.get());
                if (topWeights.size() < numTopTerms) {
                    topWeights.add(new TermPair(terms.get(w.index()), val));
                } else if (topWeights.peek().DocCount < val) {
                    topWeights.remove();
                    topWeights.add(new TermPair(terms.get(w.index()), val));
                }
            }
            final int numTopWeights = topWeights.size();
            for (int j = numTopWeights - 1; j > -1; --j) {
                topTerms.set(j, topWeights.remove().Term);
            }
            getStats(i).setTopTerms(topTerms);
        }
    }

    void initColorMap(final Set<Integer> colors) {
        colors.clear();
        for (int i = 0; i < NUM_POSSIBLE_REGION_COLORS; ++i) {
            colors.add(i);
        }
    }

    int pickColor(final int region, final Set<Integer> colors, final Integer[] choices, final Random chooser) {
        if (colors.size() > 0) {
            final int choice = chooser.nextInt(colors.size());
            return colors.toArray(choices)[choice];
        } else {
            System.err.println("The set of possible colors for region " + region + " is zero!");
            return chooser.nextInt(NUM_POSSIBLE_REGION_COLORS);
        }
    }

    void colorRegions() {
        RegionColors = new int[NumRegions];
        RegionMap = new ArrayList<Set<Integer>>(NumRegions);
        for (int i = 0; i < NumRegions; ++i) {
            RegionMap.add(new HashSet<Integer>());
            RegionColors[i] = -1;
        }
        final IntArrayWritable neighbors = new IntArrayWritable(9);
        for (int y = 0; y < height(); ++y) {
            for (int x = 0; x < width(); ++x) {
                final int id = getID(x, y);
                final int region = getStats(id).getRegion();
                getNeighborhood(x, y, 1, neighbors);
                for (int i = 0; i < neighbors.getLength(); ++i) {
                    final int neighbor = neighbors.getInts()[i];
                    final int otherRegion = getStats(neighbor).getRegion();
                    if (otherRegion != region) {
                        RegionMap.get(region).add(otherRegion);
                        RegionMap.get(otherRegion).add(region);
                    }
                }
            }
        }
        final Random chooser = new Random(29);
        final Set<Integer> colors = new HashSet<Integer>(NUM_POSSIBLE_REGION_COLORS);
        final Integer[] choices = new Integer[NUM_POSSIBLE_REGION_COLORS];
        for (int i = 0; i < NumRegions; ++i) {
            initColorMap(colors);
            for (Integer adj : RegionMap.get(i)) {
                if (RegionColors[adj] != -1) {
                    colors.remove(RegionColors[adj]);
                }
            }
            RegionColors[i] = pickColor(i, colors, choices, chooser);
        }
    }

    int maxTermDifference(final int tID, final int uID) {
        final Vector t = getCell(tID);
        final Vector u = getCell(uID);
        final Vector diff = t.minus(u);
        final double maxValue = diff.maxValue();
        final double minValue = diff.minValue();

        if (minValue < 0 && Math.abs(minValue) > maxValue) { // even if maxValue is negative, this will hold
            return -diff.minValueIndex();
        } else {
            return diff.maxValueIndex();
        }
    }

    public HashMap<Integer, Integer> getCellTermDiffs(final int cellID) {
        return CellTermDifferences.get(cellID);
    }

    public int getCellTermDiff(final int lhsID, final int rhsID) {
        if (lhsID == rhsID) {
            return 0;
        } else if (lhsID < rhsID) {
            return CellTermDifferences.get(lhsID).get(rhsID);
        } else { // rhsID < lhsID
            return -1 * CellTermDifferences.get(rhsID).get(lhsID);
        }
    }

    void setCellTermDiff(final int diff, final int lhsID, final int rhsID) {
        if (lhsID == rhsID) {
            return;
        } else if (lhsID < rhsID) {
            CellTermDifferences.get(lhsID).put(rhsID, diff);
        } else { // rhsID < lhsID
            CellTermDifferences.get(rhsID).put(lhsID, -1 * diff);
        }
    }

    void assignTermDiffs() {
        // diffs will be stored as ctd[id1][id2], where id1 < id2
        CellTermDifferences = new ArrayList<HashMap<Integer, Integer>>(NumCells);
        for (int i = 0; i < NumCells; ++i) {
            CellTermDifferences.add(new HashMap<Integer, Integer>());
        }
        int id = -1;
        int otherID = -1;
        for (int y = 0; y < Height - 1; ++y) {
            /*
              3 adjacencies involving at least one cell neither in the last row nor
              nor last column.
                --------------
                 |     |     |
                 | x,y *  1  |
                 |     |     |
                ----*--*------
                 |     |     |
                 |  2  |     |
                 |     |     |
                --------------
            */
            for (int x = 0; x < Width - 1; ++x) {
                id = getID(x, y);
                otherID = getID(x + 1, y);
                setCellTermDiff(maxTermDifference(id, otherID), id, otherID);

                otherID = getID(x, y + 1);
                setCellTermDiff(maxTermDifference(id, otherID), id, otherID);

                otherID = getID(x + 1, y + 1);
                setCellTermDiff(maxTermDifference(id, otherID), id, otherID);

                // cross cells, 1-2
                id = getID(x + 1, y);
                otherID = getID(x, y + 1);
                setCellTermDiff(maxTermDifference(id, otherID), id, otherID);
            }
            /*
              Vertical adjacencies for last column.
                --------------
                 |     |     |
                 |     | x,y |
                 |     |     |
                ----------*---  OUT OF BOUNDS
                 |     |     |
                 |     |     |
                 |     |     |
                --------------
            */
            id = getID(Width - 1, y);
            otherID = getID(Width - 1, y + 1);
            setCellTermDiff(maxTermDifference(id, otherID), id, otherID);
        }
        /*
          Horizontal adjacencies for last row.
              --------------
               |     |     |
               |     |     |
               |     |     |
              --------------
               |     |     |
               | x,y *     |
               |     |     |
              --------------
               OUT OF BOUNDS
        */
        for (int x = 0; x < Width - 1; ++x) {
            id = getID(x, Height - 1);
            otherID = getID(x + 1, Height - 1);
            setCellTermDiff(maxTermDifference(id, otherID), id, otherID);
        }
    }
}