com.joliciel.jochre.stats.DBSCANClusterer.java Source code

Java tutorial

Introduction

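Here is the source code for com.joliciel.jochre.stats.DBSCANClusterer.java, a class from the Jochre project that clusters objects with the DBSCAN algorithm using the Euclidean distance between their coordinates.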

Source

///////////////////////////////////////////////////////////////////////////////
//Copyright (C) 2012 Assaf Urieli
//
//This file is part of Jochre.
//
//Jochre is free software: you can redistribute it and/or modify
//it under the terms of the GNU Affero General Public License as published by
//the Free Software Foundation, either version 3 of the License, or
//(at your option) any later version.
//
//Jochre is distributed in the hope that it will be useful,
//but WITHOUT ANY WARRANTY; without even the implied warranty of
//MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//GNU Affero General Public License for more details.
//
//You should have received a copy of the GNU Affero General Public License
//along with Jochre.  If not, see <http://www.gnu.org/licenses/>.
//////////////////////////////////////////////////////////////////////////////
package com.joliciel.jochre.stats;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.Stack;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Performs clustering on a dataset using the DBSCAN algorithm,
 * and the Euclidean distance between coordinates corresponding to each object.
 * <p>
 * The objects and their coordinates are supplied as two parallel lists: the object at
 * position {@code i} of {@code objectSet} is located at {@code dataSet.get(i)}.
 * Points that are not reachable from any core point are reported as noise.
 * <p>
 * Each call to {@link #cluster(double, int, boolean)} keeps its working state in
 * instance fields, so a single instance must not run several clusterings concurrently.
 *
 * @author Assaf Urieli
 *
 */
public class DBSCANClusterer<T> {
    private static final Log LOG = LogFactory.getLog(DBSCANClusterer.class);

    /** The objects being clustered, in the same order as {@link #dataSet}. */
    List<T> objectSet;
    /** The coordinates of each object, in the same order as {@link #objectSet}. */
    List<double[]> dataSet;
    /** Working state of the current run: whether each point has already been examined. */
    boolean[] visited;
    /** Working state of the current run: the cluster each point belongs to, or null if none (yet). */
    List<Set<T>> clusterList;

    /**
     * Creates a clusterer over a list of objects and their coordinates.
     *
     * @param objectSet the objects to cluster
     * @param dataSet the coordinates of each object, in the same order as {@code objectSet}
     * @throws NullPointerException if either list is null
     * @throws IllegalArgumentException if the two lists do not have the same size
     */
    public DBSCANClusterer(List<T> objectSet, List<double[]> dataSet) {
        if (objectSet.size() != dataSet.size()) {
            throw new IllegalArgumentException("Object Set size has to be the same as Data Set size");
        }
        this.objectSet = objectSet;
        this.dataSet = dataSet;
    }

    /**
     * Clusters the data set with DBSCAN.
     *
     * @param epsilon maximum Euclidean distance (inclusive) at which two points are neighbours
     * @param minPoints minimum number of points within {@code epsilon} of a point, counting the
     *     point itself, for it to be a core point that seeds or extends a cluster
     * @param includeNoise if true, every noise point is added to the result as its own
     *     singleton cluster; if false, noise points are left out of the result
     * @return the clusters found, each one a set of objects
     */
    public Set<Set<T>> cluster(double epsilon, int minPoints, boolean includeNoise) {
        LOG.debug("cluster: epsilon=" + epsilon + ", minPoints=" + minPoints + ", includeNoise=" + includeNoise);
        /*
        DBSCAN(D, eps, MinPts)
           C = 0
           for each unvisited point P in dataset D
              mark P as visited
              N = getNeighbors (P, eps)
              if sizeof(N) < MinPts
                 mark P as NOISE (provisionally: P may still become a border point)
              else
                 C = next cluster
                 expandCluster(P, N, C, eps, MinPts)
         */
        int size = dataSet.size();
        this.visited = new boolean[size];
        this.clusterList = new ArrayList<Set<T>>(size);
        for (int i = 0; i < size; i++) {
            this.clusterList.add(null);
        }

        Set<Set<T>> clusters = new HashSet<Set<T>>();
        for (int index = 0; index < size; index++) {
            if (visited[index]) {
                continue;
            }
            visited[index] = true;
            Set<Integer> neighbours = this.getNeighbours(index, epsilon);
            // getNeighbours excludes the point itself, so a core point needs minPoints - 1
            // neighbours. Sparser points stay unassigned for now: a later cluster may
            // still absorb them as border points.
            if (neighbours.size() >= minPoints - 1) {
                Set<T> cluster = new HashSet<T>();
                expandCluster(index, neighbours, cluster, epsilon, minPoints);
                clusters.add(cluster);
            }
        }

        // Noise can only be decided once every cluster is complete; deciding it on first
        // visit would report border points both inside their cluster and as noise.
        Set<T> noise = new HashSet<T>();
        for (int index = 0; index < size; index++) {
            if (clusterList.get(index) == null) {
                noise.add(objectSet.get(index));
            }
        }

        LOG.debug("Found " + clusters.size() + " clusters");
        LOG.debug("Found " + noise.size() + " noise");
        if (includeNoise) {
            for (T object : noise) {
                Set<T> oneObject = new HashSet<T>();
                oneObject.add(object);
                clusters.add(oneObject);
            }
        }
        return clusters;
    }

    /**
     * Grows {@code cluster} from the core point {@code index}. Every point that is
     * density-reachable from it and not yet assigned to a cluster is added.
     *
     * @param index the core point seeding the cluster; the caller has already marked it visited
     * @param neighbours the neighbours of {@code index}, as returned by {@link #getNeighbours}
     * @param cluster the (initially empty) set receiving the cluster's objects
     * @param epsilon neighbourhood radius
     * @param minPoints minimum neighbourhood size, counting the point itself, for a core point
     */
    private void expandCluster(int index, Set<Integer> neighbours, Set<T> cluster, double epsilon, int minPoints) {
        /*
        expandCluster(P, N, C, eps, MinPts)
           add P to cluster C
           for each point P' in N
              if P' is not visited
                 mark P' as visited
                 N' = getNeighbors(P', eps)
                 if sizeof(N') >= MinPts
                    N = N joined with N'
              if P' is not yet member of any cluster
                 add P' to cluster C
        */
        cluster.add(objectSet.get(index));
        clusterList.set(index, cluster);
        // Used as a LIFO stack: points are appended to and removed from the same end.
        Deque<Integer> points = new ArrayDeque<Integer>(neighbours);

        while (!points.isEmpty()) {
            int i = points.removeLast();
            if (!visited[i]) {
                visited[i] = true;
                Set<Integer> nPrime = this.getNeighbours(i, epsilon);
                // Only core points extend the frontier; border points join but do not expand.
                if (nPrime.size() >= minPoints - 1) {
                    points.addAll(nPrime);
                }
            }
            // Points already assigned (to this or an earlier cluster) keep their assignment.
            if (clusterList.get(i) == null) {
                cluster.add(objectSet.get(i));
                clusterList.set(i, cluster);
            }
        }
    }

    /**
     * Finds the neighbours of a point, using the Euclidean distance.
     * <p>
     * Only the first {@code dataSet.get(i).length} coordinates are compared. This presumably
     * assumes all coordinate arrays have the same length; confirm against callers.
     *
     * @param i index of the point in {@link #dataSet}
     * @param epsilon maximum distance (inclusive) at which another point counts as a neighbour
     * @return the indexes of all points other than {@code i} lying within {@code epsilon} of it
     */
    Set<Integer> getNeighbours(int i, double epsilon) {
        Set<Integer> neighbours = new HashSet<Integer>();
        double[] point = dataSet.get(i);
        int dimensions = point.length;
        for (int j = 0; j < dataSet.size(); j++) {
            if (i == j) {
                continue;
            }
            double[] otherPoint = dataSet.get(j);
            double sum = 0.0;
            for (int n = 0; n < dimensions; n++) {
                double diff = point[n] - otherPoint[n];
                sum += (diff * diff);
            }
            double distance = Math.sqrt(sum);
            if (distance <= epsilon) {
                neighbours.add(j);
            }
        }
        return neighbours;
    }
}