oculus.xdataht.clustering.ClusterResults.java Source code

Java tutorial

Introduction

Here is the source code for oculus.xdataht.clustering.ClusterResults.java

Source

/**
 * Copyright (c) 2013 Oculus Info Inc.
 * http://www.oculusinfo.com/
 *
 * Released under the MIT License.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy of
 * this software and associated documentation files (the "Software"), to deal in
 * the Software without restriction, including without limitation the rights to
 * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
 * of the Software, and to permit persons to whom the Software is furnished to do
 * so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
package oculus.xdataht.clustering;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import oculus.xdataht.data.DataRow;
import oculus.xdataht.data.DataTable;
import oculus.xdataht.data.DenseDataTable;
import oculus.xdataht.data.TableDB;
import oculus.xdataht.data.TableDistribution;
import oculus.xdataht.model.RestLinkCriteria;
import oculus.xdataht.util.Pair;

import org.json.JSONArray;
import org.json.JSONException;

import com.oculusinfo.ml.Instance;
import com.oculusinfo.ml.unsupervised.cluster.Cluster;
import com.oculusinfo.ml.unsupervised.cluster.ClusterResult;

public class ClusterResults implements Serializable {
    /**
     * 
     */
    private static final long serialVersionUID = 1L;
    private Map<String, Set<String>> _clusteringResult = new HashMap<String, Set<String>>();
    private final String _datasetName;
    private final List<Pair<String, Double>> _clusterParameters;
    private final String _clusterParamtersString;

    public ClusterResults(String datasetName, String clusterParamtersString) {
        _datasetName = datasetName;
        _clusterParameters = new ArrayList<Pair<String, Double>>();
        _clusterParamtersString = clusterParamtersString;
    }

    public ClusterResults(String datasetName, ClusterResult mlClusterResult, List<Pair<String, Double>> params,
            String clusterParamtersString) {
        _clusterParamtersString = clusterParamtersString;

        // Convert ml result into hashmap of membership
        for (Cluster c : mlClusterResult) {
            Set<String> members = new HashSet<String>();
            for (Instance inst : c.getMembers()) {
                members.add(inst.getId());
            }
            _clusteringResult.put(c.getId(), members);
        }

        _datasetName = datasetName;
        _clusterParameters = params;
        sortClusterParams();
    }

    public HashMap<String, String> getSummary() {
        HashMap<String, String> summary = new HashMap<String, String>();

        // Total clusters
        summary.put("Total Clusters", _clusteringResult.size() + "");

        // Max/Min
        int maxSize = Integer.MIN_VALUE;
        int minSize = Integer.MAX_VALUE;
        int totalSize = 0;
        for (String id : _clusteringResult.keySet()) {
            Set<String> members = _clusteringResult.get(id);
            if (members.size() > maxSize) {
                maxSize = members.size();
            }
            if (members.size() < minSize) {
                minSize = members.size();
            }
            totalSize += members.size();
        }
        double averageSize = (double) totalSize / _clusteringResult.size();

        summary.put("Minimum Cluster Size", minSize + "");
        summary.put("Maximum Cluster Size", maxSize + "");
        summary.put("Average Cluster Size", averageSize + "");

        return summary;
    }

    private void sortClusterParams() {
        Collections.sort(_clusterParameters, new Comparator<Pair<String, Double>>() {
            public int compare(Pair<String, Double> o1, Pair<String, Double> o2) {
                return o1.getSecond().compareTo(o2.getSecond());
            }
        });
    }

    public ClusterResults(String datasetName, Map<String, Set<String>> clusteringResult,
            String clusterParamtersString) {
        _clusterParamtersString = clusterParamtersString;
        _datasetName = datasetName;
        _clusteringResult = clusteringResult;
        _clusterParameters = new ArrayList<Pair<String, Double>>();

        // Build param list from the supplied JSON string
        String[] paramsSplit = clusterParamtersString.split("-");
        try {
            JSONArray attrsJson = new JSONArray(paramsSplit[1]);
            ArrayList<String> attrs = new ArrayList<String>();
            for (int i = 0; i < attrsJson.length(); i++) {
                attrs.add(attrsJson.getString(i));
            }

            JSONArray weightsJson = new JSONArray(paramsSplit[2]);
            ArrayList<Double> weights = new ArrayList<Double>();
            for (int i = 0; i < weightsJson.length(); i++) {
                weights.add(weightsJson.getDouble(i));
            }

            if (attrs.size() != weights.size())
                throw new IllegalArgumentException();

            for (int i = 0; i < attrs.size(); ++i) {
                _clusterParameters.add(new Pair<String, Double>(attrs.get(i), weights.get(i)));
            }
        } catch (JSONException e) {
            throw new IllegalArgumentException();
        }

        sortClusterParams();
    }

    public List<Pair<String, Double>> getClusterParameters() {
        return _clusterParameters;
    }

    public Map<String, Set<String>> getClustersById() {
        return _clusteringResult;
    }

    public Set<String> getMembers(String id) {
        return _clusteringResult.get(id);
    }

    private static HashMap<String, Integer> getFeatureCounts(DataTable table, String key, Set<String> c) {
        HashMap<String, Integer> valueCounts = new HashMap<String, Integer>();
        if (table != null) {
            for (String memberId : c) {
                DataRow originalRow = table.getRowById(memberId);
                if (originalRow != null) {
                    String value = originalRow.get(key);
                    if (value != null)
                        value = value.trim().toLowerCase();
                    if (!(value == null || value.startsWith("null") || value.length() == 0)) {
                        Integer currentCount = valueCounts.get(value);
                        if (currentCount == null) {
                            currentCount = 0;
                        }
                        currentCount++;
                        valueCounts.put(value, currentCount);
                    }
                }
            }
        }
        return valueCounts;
    }

    private static int getCommonPropertyValueCount(DataTable table, Set<String> a, Set<String> b, String key,
            HashSet<String> ignoreValues) {

        HashMap<String, Integer> aCounts = getFeatureCounts(table, key, a);
        HashMap<String, Integer> bCounts = getFeatureCounts(table, key, b);
        HashMap<String, Integer> commonValues = new HashMap<String, Integer>();

        // Create a map of all common keys between aCounts and bCounts
        Set<String> allKeys = new HashSet<String>(aCounts.keySet());
        allKeys.addAll(bCounts.keySet());

        // Get counts of all common attributes
        for (String value : allKeys) {
            if (ignoreValues != null && ignoreValues.contains(value))
                continue;
            Integer bCount = bCounts.get(value);
            Integer aCount = aCounts.get(value);
            if (aCount != null && bCount != null) {
                commonValues.put(value, Math.min(aCount, bCount));
            }
        }

        // Return the sum of all count values
        Integer count = 0;
        for (String value : commonValues.keySet()) {
            count += commonValues.get(value);
        }
        return count;
    }

    public List<String> filter(Map<String, Integer> excludedClusters) {
        List<String> result = new ArrayList<String>();
        for (String clusterId : _clusteringResult.keySet()) {
            if (!excludedClusters.containsKey(clusterId)) {
                result.add(clusterId);
            }
        }
        return result;
    }

    public List<String> dbFilter(ArrayList<LinkFilter> filters) {
        List<String> result = new ArrayList<String>();
        String where = "where ";
        boolean doSqlSelect = false;
        boolean testClusterSize = false;
        double sizeThreshold = 0;
        LinkFilter.Condition sizeCondition = LinkFilter.Condition.LT;
        boolean isFirst = true;
        ArrayList<LinkFilter> tagFilters = new ArrayList<LinkFilter>();
        for (LinkFilter filter : filters) {
            if (filter.filterAttribute.equals("Cluster Size")) {
                testClusterSize = true;
                sizeThreshold = Double.parseDouble(filter.value);
                sizeCondition = filter.condition;
            } else if (filter.filterAttribute.equals("tag")) {
                doSqlSelect = true;
                tagFilters.add(filter); // defer processing of tag filters until others are done
            } else {
                doSqlSelect = true;
                if (isFirst)
                    isFirst = false;
                else
                    where += " AND ";
                where += filter.getWhereClause();
            }
        }

        if (tagFilters.size() > 0) {
            if (!where.equals("where ")) {
                where += " AND ";
            }
            where += "( ";
            for (int i = 0; i < tagFilters.size() - 1; i++) {
                where += tagFilters.get(i).getWhereClause() + " OR ";
            }
            where += tagFilters.get(tagFilters.size() - 1).getWhereClause() + ")";
        }

        List<String> matchingAds = doSqlSelect
                ? TableDB.getInstance().getMatches(_datasetName, where, tagFilters.size() > 0)
                : null;
        for (String clusterId : _clusteringResult.keySet()) {
            Set<String> members = _clusteringResult.get(clusterId);
            if (testClusterSize) {
                int size = members.size();
                if (!LinkFilter.testNumber(size, sizeThreshold, sizeCondition)) {
                    continue;
                } else if (!doSqlSelect) {
                    result.add(clusterId);
                }
            }
            if (doSqlSelect) {
                for (String member : members) {
                    if (matchingAds.contains(member)) {
                        result.add(clusterId);
                        break;
                    }
                }
            }
        }
        return result;
    }

    public List<String> filter(ArrayList<LinkFilter> filters) {
        List<String> result = new ArrayList<String>();
        ArrayList<String> columns = new ArrayList<String>();
        for (LinkFilter filter : filters) {
            if (!(filter.filterAttribute.equals("Cluster Size") || filter.filterAttribute.equals("tag"))) {
                columns.add(filter.filterAttribute);
            }
        }
        DataTable table = TableDB.getInstance().getDataTableColumns(_datasetName, columns);
        for (String clusterId : _clusteringResult.keySet()) {
            Set<String> members = _clusteringResult.get(clusterId);
            boolean addCluster = true;
            for (LinkFilter lf : filters) {
                addCluster &= lf.testCluster(table, members, _datasetName);
                if (!addCluster) {
                    break;
                }
            }
            if (addCluster) {
                result.add(clusterId);
            }

        }
        return result;
    }

    public Map<String, List<ClusterLink>> getConnectivity(Iterable<String> srcClusters,
            Iterable<String> dstClusters, String datasetName, ArrayList<RestLinkCriteria> linkCriteria,
            boolean ignoreCommon) {
        Map<String, List<ClusterLink>> links = new HashMap<String, List<ClusterLink>>();
        Set<String> addedLinks = new HashSet<String>();

        ArrayList<String> allLinkKeys = new ArrayList<String>();
        for (RestLinkCriteria rlc : linkCriteria) {
            for (String attr : rlc.getAttributes()) {
                if (allLinkKeys.indexOf(attr) == -1) {
                    allLinkKeys.add(attr);
                }
            }
        }
        if (allLinkKeys.size() == 0)
            return links;
        DataTable table = TableDB.getInstance().getDataTableColumns(datasetName, allLinkKeys);

        HashMap<String, HashSet<String>> ignoreValues = new HashMap<String, HashSet<String>>();
        if (ignoreCommon) {
            // Create sets of values to ignore (because they are too common) for each linkBy attribute
            for (String key : allLinkKeys) {
                TableDistribution td = TableDB.getInstance().getValueCounts(datasetName, key);
                for (Map.Entry<String, Integer> entry : td.distribution.entrySet()) {
                    if (entry.getValue() > 20) {
                        HashSet<String> ignore = ignoreValues.get(key);
                        if (ignore == null) {
                            ignore = new HashSet<String>();
                            ignoreValues.put(key, ignore);
                        }
                        ignore.add(entry.getKey());
                    }
                }
            }
        }

        double sum = 0.0;
        for (RestLinkCriteria rlc : linkCriteria) {
            sum += rlc.getWeight();
        }
        HashMap<RestLinkCriteria, Double> normalizedWeights = new HashMap<RestLinkCriteria, Double>();
        for (RestLinkCriteria rlc : linkCriteria) {
            normalizedWeights.put(rlc, rlc.getWeight() / sum);
        }

        for (String srcId : srcClusters) {
            for (String dstId : dstClusters) {

                if (srcId.equals(dstId)) {
                    continue;
                }

                // Each set of attributes in the RestLinkCriteria must be true for a pair of clusters to be connected
                // Create a link for each RestLinkCriteria (if they pass the connectivity test)
                for (RestLinkCriteria rlc : linkCriteria) {

                    // Make sure we don't double up links
                    if (addedLinks.contains(srcId + "-" + dstId + "-" + rlc.getName())
                            || addedLinks.contains(dstId + "-" + srcId + "-" + rlc.getName())) {
                        continue;
                    }

                    boolean createLink = true;
                    for (String attr : rlc.getAttributes()) {
                        int numSharedValues = getCommonPropertyValueCount(table, _clusteringResult.get(srcId),
                                _clusteringResult.get(dstId), attr, ignoreValues.get(attr));
                        if (numSharedValues <= 0) {
                            createLink = false;
                            break;
                        }
                    }
                    if (createLink) {
                        List<ClusterLink> srcAdjList = links.get(srcId);
                        if (srcAdjList == null) {
                            srcAdjList = new ArrayList<ClusterLink>();
                        }
                        ClusterLink edge = new ClusterLink();
                        edge.linkedClusterId = dstId;
                        edge.clusterLinkAttribute = rlc.getName();
                        edge.weight = normalizedWeights.get(rlc);

                        addedLinks.add(srcId + "-" + dstId + "-" + rlc.getName());

                        srcAdjList.add(edge);
                        links.put(srcId, srcAdjList);
                    }
                }
            }
        }
        return links;
    }

    public Map<String, List<ClusterLink>> getConnectivity(Iterable<String> clusters, String datasetName,
            ArrayList<RestLinkCriteria> linkCriteria) {
        return getConnectivity(clusters, clusters, datasetName, linkCriteria, false);
    }

    public List<DataRow> getClusterDetails(String clusterId) {
        List<DataRow> results = new ArrayList<DataRow>();

        Set<String> members = _clusteringResult.get(clusterId);
        DataTable table = TableDB.getInstance().getDataTableMembers(_datasetName, members);

        for (String memberId : members) {
            DataRow row = table.getRowById(memberId);
            if (row != null) {

                // Add any user tags to the row
                ArrayList<String> tags = TableDB.getInstance().getTags(memberId);
                String tagString = "";
                if (tags != null && tags.size() > 0) {
                    for (int i = 0; i < tags.size() - 1; i++) {
                        tagString += tags.get(i) + ',';
                    }
                    tagString += tags.get(tags.size() - 1);
                }
                row.put("tags", tagString);
                results.add(row);
            }
        }

        return results;
    }

    public Set<String> getCluster(String clusterId) {
        return _clusteringResult.get(clusterId);
    }

    public String getClusterName(DataTable table, String clusterId) {
        // Grab any example of highest weighted feature
        Set<String> members = _clusteringResult.get(clusterId);
        String key = _clusterParameters.get(0).getFirst(); // List is sorted by weight
        for (String memberId : members) {
            DataRow row = table.getRowById(memberId);
            String value = row.get(key);
            if (value != null) {
                return value;
            } else {
                return "";
            }
        }
        return "EMPTY";
    }

    public String getClusterName(DenseDataTable table, String clusterId) {
        // Grab any example of highest weighted feature
        Set<String> members = _clusteringResult.get(clusterId);
        String key = _clusterParameters.get(0).getFirst(); // List is sorted by weight
        int keyIdx = table.columns.indexOf(key);
        for (String memberId : members) {
            String[] row = table.getRowById(memberId);
            String value = row[keyIdx];
            if (value != null) {
                return value;
            } else {
                return "";
            }
        }
        return "EMPTY";
    }

    public String getClusterParametersString() {
        return _clusterParamtersString;
    }

    public static ClusterResults clusterTableWithML(String datasetName, ArrayList<String> clusterBy,
            ArrayList<Double> weights, String paramsKey) throws InterruptedException {
        DataTable dt = TableDB.getInstance().getDataTableColumns(datasetName, clusterBy);
        EntityClustering clusterer = new EntityClustering(dt.rows, clusterBy, weights);
        ClusterResult mlClusterResult = clusterer.cluster(datasetName);

        List<Pair<String, Double>> pairedParams = new ArrayList<Pair<String, Double>>();
        for (int i = 0; i < clusterBy.size(); i++) {
            pairedParams.add(new Pair<String, Double>(clusterBy.get(i), weights.get(i)));
        }

        return new ClusterResults(datasetName, mlClusterResult, pairedParams, paramsKey);
    }
}