fall2015.b565.wisBreastCancer.KMeans.java Source code

Java tutorial

Introduction

Here is the source code for fall2015.b565.wisBreastCancer.KMeans.java

Source

/*
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 *
*/

package fall2015.b565.wisBreastCancer;

import com.google.common.primitives.Doubles;
import com.google.common.primitives.Ints;
import fall2015.b565.wisBreastCancer.utils.Constants;
import fall2015.b565.wisBreastCancer.utils.PropertyReader;
import org.apache.commons.math3.stat.correlation.PearsonsCorrelation;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.*;

public class KMeans {
    private static final Logger logger = LoggerFactory.getLogger(KMeans.class);
    private static int k;
    private static int l;
    private static double threshold = .1;
    private static int iterations = 10;
    private static int[] allAttributeHeaders = { 0, 1, 2, 3, 4, 5, 6, 7, 8 };
    private static int totalAttributes = 9;
    private static int vfoldBlocks = 0;

    public KMeans() {
        try {
            readConfigurations();
        } catch (Exception e) {
            logger.error("Error occurred while reading configuration files !!!");
        }
    }

    public List<PPV> findKmeansToAttributePowerSet(String cleanedfilePath) throws Exception {
        Set<Set<Integer>> kMeansPowerSet = getPowerSet(new HashSet<Integer>(Ints.asList(allAttributeHeaders)));
        List<PPV> ppvList = new ArrayList<PPV>();
        List<Record> records = getRecords(cleanedfilePath);
        for (Set<Integer> set : kMeansPowerSet) {
            if (set.size() != 0) {
                // changing k
                for (int i = 2; i < 20; i += 2) {
                    PPV ppvInstance = new PPV();
                    KMeansResult result = calculate(records, set, i, l);
                    double ppv = calculatePPV(result.getFinalCentroids(), result.getInitialRecords());
                    ppvInstance.setPpv(ppv);
                    ppvInstance.setAttributeHeaders(Ints.toArray(set));
                    ppvInstance.setAssociateKVal(i);
                    ppvList.add(ppvInstance);
                }
            }
        }
        Double maxPPV = 0.0;
        int associateKVal = 2;
        int[] attributeHeaders = null;
        String attributeHeader = "";
        for (PPV ppv : ppvList) {
            double ppvVal = ppv.getPpv();
            if (ppvVal > maxPPV) {
                maxPPV = ppvVal;
                associateKVal = ppv.getAssociateKVal();
                attributeHeaders = ppv.getAttributeHeaders();
            }
        }
        if (attributeHeaders != null) {
            for (int header : attributeHeaders) {
                attributeHeader += header + ",";
            }
        }
        System.out.println("Max PPV : " + maxPPV + " comes when Attribute Set is : " + attributeHeader
                + " and when k = " + associateKVal);
        return ppvList;
    }

    public KMeansResult findKmeansToAllAttributes(String cleanedFilePath) throws Exception {
        List<Record> records = getRecords(cleanedFilePath);
        // make data structure of attributes
        HashSet<Integer> attributes = new HashSet<Integer>(Ints.asList(allAttributeHeaders));
        return calculate(records, attributes, k, l);
    }

    public void findAttributeCorrelations() throws Exception {
        List<Record> records = getRecords(Constants.REPLACED_DATA_FILE_PATH);
        Map<Integer, List<Integer>> attributes = getAttributes(records);
        Map<int[], Double> corelationMap = findCorrelation(attributes);
        double max = 0;
        int[] maxPair = { 0, 0 };
        for (int[] pair : corelationMap.keySet()) {
            double correlation = corelationMap.get(pair);
            if (max < correlation) {
                max = correlation;
                maxPair = pair;
            }
            System.out.println("Attribute pair : {" + pair[0] + "," + pair[1] + "} : correlation : " + correlation);
        }
        System.out.println("Highest Correlation : " + max + " comes for attribute pair : {" + maxPair[0] + ", "
                + maxPair[1] + "}");

    }

    public Map<int[], Double> findCorrelation(Map<Integer, List<Integer>> attributes) {
        Set<Integer> indexes = attributes.keySet();
        Map<int[], Double> correlationMap = new HashMap<int[], Double>();
        List<int[]> possiblePairs = getPossiblePairs(indexes);
        for (int[] pair : possiblePairs) {
            List<Integer> attribute1 = attributes.get(pair[0]);
            List<Integer> attribute2 = attributes.get(pair[1]);
            double[] attributeArray1 = Doubles.toArray(attribute1);
            double[] attributeArray2 = Doubles.toArray(attribute2);
            PearsonsCorrelation pearsonsCorrelation = new PearsonsCorrelation();
            double correlation = pearsonsCorrelation.correlation(attributeArray1, attributeArray2);
            correlationMap.put(pair, correlation);
        }
        return correlationMap;
    }

    private List<int[]> getPossiblePairs(Set<Integer> indexes) {
        List<int[]> pairs = new ArrayList<int[]>();
        for (int i = 1; i < indexes.size() + 1; i++) {
            for (int j = i + 1; j < indexes.size() + 1; j++) {
                if (i != j) {
                    int[] indexPair = { i, j };
                    pairs.add(indexPair);
                }
            }
        }
        return pairs;
    }

    public KMeansResult calculate(List<Record> records, Set<Integer> attributes, int k, int l) {
        try {
            KMeansResult result = new KMeansResult();
            // randomly select K centroids
            List<Centroid> centroids = generateCentroids(k);
            List<Centroid> randomCentroids = generateCentroids(k);

            double previousCentroidDifference = centroidDistance(centroids, randomCentroids, attributes);
            //            double previousCentroidDifference = Double.MIN_VALUE;
            double currentCentroidDifference = 0;
            // calculate distance for each data point
            // assign data points to centroids
            List<Centroid> lastAssignedCentroids = centroids;
            for (int i = 0; i < iterations; i++) {
                assignRecordsToCentroids(centroids, records, attributes, l);
                lastAssignedCentroids = centroids;
                // calculate average of the cluster and make them as new data point
                // do the same again until threshold satisfies
                List<Centroid> newCentroids = calculateAvgCentroid(centroids, records);

                // break if the centroids doesn't change
                currentCentroidDifference = centroidDistance(newCentroids, centroids, attributes);
                double difference = Math.abs(currentCentroidDifference - previousCentroidDifference);
                if (difference <= threshold * previousCentroidDifference) {
                    //                    System.out.println("KMeans finished in iterations: " + i + " with threshold difference: " + difference);
                    break;
                }
                previousCentroidDifference = currentCentroidDifference;
                centroids = newCentroids;
            }
            for (Centroid centroid : lastAssignedCentroids) {
                centroid.getAssignedRecords().clear();
            }
            assignRecordsToCentroids(lastAssignedCentroids, records, attributes);
            result.setFinalCentroids(lastAssignedCentroids);
            result.setInitialRecords(records);
            printClasses(lastAssignedCentroids, records);
            return result;
        } catch (Exception e) {
            String s = "Error occurred while calculating KMeans algorithm";
            logger.error(s, e);
            throw new RuntimeException(s, e);
        }
    }

    public double centroidDistance(List<Centroid> centroids1, List<Centroid> centroids2,
            Set<Integer> attributeIndexes) {
        double sum = 0;
        for (int i = 0; i < centroids1.size(); i++) {
            Centroid c1 = centroids1.get(i);
            Centroid c2 = centroids2.get(i);

            double distance = euclideanDistance(c1.getRandomRecord().getAttributes(),
                    c2.getRandomRecord().getAttributes(), attributeIndexes);
            sum += distance;
        }
        return sum;
    }

    public Set<Set<Integer>> getPowerSet(Set<Integer> originalSet) {
        Set<Set<Integer>> sets = new HashSet<Set<Integer>>();
        if (originalSet.isEmpty()) {
            sets.add(new HashSet<Integer>());
            return sets;
        }
        List<Integer> list = new ArrayList<Integer>(originalSet);
        Integer head = list.get(0);
        Set<Integer> rest = new HashSet<Integer>(list.subList(1, list.size()));
        for (Set<Integer> set : getPowerSet(rest)) {
            Set<Integer> newSet = new HashSet<Integer>();
            newSet.add(head);
            newSet.addAll(set);
            sets.add(newSet);
            sets.add(set);
        }
        return sets;
    }

    public double calculatePPV(List<Centroid> centroids, List<Record> records) {
        HashMap<Integer, HashMap<Integer, Integer>> centroidClassMap = new HashMap<Integer, HashMap<Integer, Integer>>();
        for (Centroid centroid : centroids) {
            List<Integer> assignedRecords = centroid.getAssignedRecords();
            HashMap<Integer, Integer> classCountMap = new HashMap<Integer, Integer>();
            for (Integer assignedRecord : assignedRecords) {
                Record record = records.get(assignedRecord);
                int dataClass = record.getDataClass();
                Integer count = classCountMap.get(dataClass);
                if (count == null) {
                    count = 0;
                }
                count++;
                classCountMap.put(dataClass, count);
            }
            centroidClassMap.put(centroid.getCentroidId(), classCountMap);
        }
        Integer totalTP = 0;
        Integer totalFP = 0;
        for (Integer centroidId : centroidClassMap.keySet()) {
            HashMap<Integer, Integer> classCountMap = centroidClassMap.get(centroidId);
            Integer maxCount = 0;
            Integer totalCount = 0;
            for (Integer classNumber : classCountMap.keySet()) {
                Integer class1Count = classCountMap.get(classNumber);
                totalCount += class1Count;
                if (class1Count > maxCount) {
                    maxCount = class1Count;
                }
            }
            totalTP += maxCount;
            totalFP += (totalCount - maxCount);
        }
        double ppval = (double) totalTP / (totalTP + totalFP);
        //        System.out.println("Total TP : " + totalTP + " TotalTP + TotalFP : " + (totalTP + totalFP) + " PPV : " + ppval);
        return ppval;

    }

    public void printClasses(List<Centroid> centroids, List<Record> records) {
        for (int i = 0; i < centroids.size(); i++) {
            System.out.println("Centroid: " + i + " has " + centroids.get(i).getAssignedRecords().size());
        }
        for (int i = 0; i < records.size(); i++) {
            //            System.out.println("Record: " + records.get(i).getScn() + " c = " + records.get(i).getDataClass() + " cent = " + records.get(i).getCentroidDistances() + " record attribute length: " + records.get(i).getAttributes().length);
        }
    }

    public void readConfigurations() throws Exception {
        try {
            k = Integer.valueOf(PropertyReader.getProperty(Constants.NUMBER_OF_CENTROIDS));
            if (k == 0) {
                logger.info(
                        "Number of centroids are not defined in the properties file. Hence using default value: "
                                + Constants.DEFAULT_NUMBER_OF_CENTROIDS);
                k = Constants.DEFAULT_NUMBER_OF_CENTROIDS;
            }
            l = Integer.valueOf(PropertyReader.getProperty(Constants.NUMBER_OF_CENTROIDS_DATA_ASSIGNED));
            //            if (l == 0 || l > k){
            //                logger.info("L is not defined in the properties file or defined L is larger than K. Hence using default value: " + Constants.DEFAULT_L);
            //                l = Constants.DEFAULT_L;
            //            }
            threshold = Double.valueOf(PropertyReader.getProperty(Constants.THRESHHOLD));
            iterations = Integer.parseInt(PropertyReader.getProperty(Constants.ITERATIONS));
            vfoldBlocks = Integer.parseInt(PropertyReader.getProperty(Constants.VFOLD_VALIDATION_BLOCKS));

        } catch (Exception e) {
            logger.error("Error occurred while reading configuration file", e);
            throw new Exception("Error occurred while reading configuration file", e);
        }
    }

    public List<Record> getRecords(String filePath) throws Exception {
        List<Record> recordList = new ArrayList<Record>();
        try {
            InputStream is = new FileInputStream(filePath);
            // Construct BufferedReader from FileReader
            BufferedReader br = new BufferedReader(new InputStreamReader(is));

            String line = null;
            while ((line = br.readLine()) != null) {
                String[] lineSplits = line.split(",");
                if (lineSplits.length > totalAttributes) {
                    Record record = new Record(Integer.valueOf(lineSplits[0]), totalAttributes);
                    for (int i = 1; i < totalAttributes - 1; i++) {
                        record.setAttribute(i - 1, Integer.parseInt(lineSplits[i]));
                    }
                    record.setDataClass(Integer.valueOf(lineSplits[10]));
                    recordList.add(record);
                }
            }
            br.close();
            return recordList;
        } catch (Exception e) {
            logger.error("Error occurred while reading cleaned data file", e);
            throw new Exception("Error occurred while reading cleaned data file", e);
        }
    }

    public Map<Integer, List<Integer>> getAttributes(List<Record> records) {
        Map<Integer, List<Integer>> attributeMap = new HashMap<Integer, List<Integer>>();
        List<Integer> attriA1 = new ArrayList<Integer>();
        List<Integer> attriA2 = new ArrayList<Integer>();
        List<Integer> attriA3 = new ArrayList<Integer>();
        List<Integer> attriA4 = new ArrayList<Integer>();
        List<Integer> attriA5 = new ArrayList<Integer>();
        List<Integer> attriA6 = new ArrayList<Integer>();
        List<Integer> attriA7 = new ArrayList<Integer>();
        List<Integer> attriA8 = new ArrayList<Integer>();
        List<Integer> attriA9 = new ArrayList<Integer>();
        for (Record record : records) {
            int[] attributes = record.getAttributes();
            attriA1.add(attributes[0]);
            attriA2.add(attributes[1]);
            attriA3.add(attributes[2]);
            attriA4.add(attributes[3]);
            attriA5.add(attributes[4]);
            attriA6.add(attributes[5]);
            attriA7.add(attributes[6]);
            attriA8.add(attributes[7]);
            attriA9.add(attributes[8]);
        }
        attributeMap.put(1, attriA1);
        attributeMap.put(2, attriA2);
        attributeMap.put(3, attriA3);
        attributeMap.put(4, attriA4);
        attributeMap.put(5, attriA5);
        attributeMap.put(6, attriA6);
        attributeMap.put(7, attriA7);
        attributeMap.put(8, attriA8);
        attributeMap.put(9, attriA9);
        return attributeMap;
    }

    public double euclideanDistance(int[] attrib1, int[] attrib2, Set<Integer> attributeIndexes) {
        double v = 0;
        for (int i : attributeIndexes) {
            v += Math.pow(((Integer) attrib1[i] - (Integer) attrib2[i]), 2);
        }
        return Math.sqrt(v);
    }

    public List<Centroid> generateCentroids(int k) throws Exception {
        // read number of centroids from properties file
        List<Centroid> randomCentroids = new ArrayList<Centroid>();
        Random random = new Random();
        for (int j = 0; j < k; j++) {
            Record record = new Record(0, totalAttributes);
            for (int i = 0; i < totalAttributes; i++) {
                int r = (int) (random.nextDouble() * 10);
                record.setAttribute(i, r);
            }
            Centroid centroid = new Centroid(record, j);
            randomCentroids.add(centroid);
        }
        return randomCentroids;
    }

    public void assignRecordsToCentroids(List<Centroid> centroidList, List<Record> recordList,
            Set<Integer> attributeIndexes) {
        for (int j = 0; j < recordList.size(); j++) {
            double[] distancesToEachCentroid = new double[centroidList.size()];
            Record record = recordList.get(j);
            int[] attributeList1 = record.getAttributes();
            for (int i = 0; i < centroidList.size(); i++) {
                Centroid centroid = centroidList.get(i);
                Record randomRecord = centroid.getRandomRecord();
                int[] attributeList2 = randomRecord.getAttributes();
                distancesToEachCentroid[i] = euclideanDistance(attributeList1, attributeList2, attributeIndexes);
            }
            Map<Integer, Double> indexDistanceMap = calculateMin(distancesToEachCentroid);
            for (Integer index : indexDistanceMap.keySet()) {
                centroidList.get(index).addRecordToAssignedList(j);
                CentroidDistance centroidDistance = new CentroidDistance(index, indexDistanceMap.get(index));
                record.addCentroidDistance(centroidDistance);
            }
        }
    }

    // This need to be changed according to l
    public void assignRecordsToCentroids(List<Centroid> centroidList, List<Record> recordList,
            Set<Integer> attributeIndexes, int l) {
        for (int j = 0; j < recordList.size(); j++) {
            List<CentroidDistance> distancesToEachCentroid = new ArrayList<CentroidDistance>();
            Record record = recordList.get(j);
            record.getCentroidDistances().clear();
            int[] attributeList1 = record.getAttributes();
            for (int i = 0; i < centroidList.size(); i++) {
                Centroid centroid = centroidList.get(i);
                Record randomRecord = centroid.getRandomRecord();
                int[] attributeList2 = randomRecord.getAttributes();
                distancesToEachCentroid.add(new CentroidDistance(i,
                        euclideanDistance(attributeList1, attributeList2, attributeIndexes)));
            }
            Collections.sort(distancesToEachCentroid);
            for (int i = 0; i < l; i++) {
                CentroidDistance centroidDistance = distancesToEachCentroid.get(i);
                int index = centroidDistance.centroidIndex;
                centroidList.get(index).addRecordToAssignedList(j);
                record.addCentroidDistance(centroidDistance);
            }
        }
    }

    public Map<Integer, Double> calculateMin(double[] distancesToEachCentroid) {
        Map<Integer, Double> distanceIndexMap = new HashMap<Integer, Double>();
        double min = distancesToEachCentroid[0];
        int index = 0;
        for (int k = 1; k < distancesToEachCentroid.length; k++) {
            if (distancesToEachCentroid[k] < min) {
                min = distancesToEachCentroid[k];
                index = k;
            }
        }
        distanceIndexMap.put(index, min);
        return distanceIndexMap;
    }

    public List<Centroid> calculateAvgCentroid(List<Centroid> centroidList, List<Record> recordList) {
        List<Centroid> newCentroidList = new ArrayList<Centroid>();

        for (int i = 0; i < centroidList.size(); i++) {
            Centroid centroid = new Centroid(new Record(0, totalAttributes), i);
            newCentroidList.add(centroid);
        }

        double[] sum = new double[newCentroidList.size()];

        for (Record r : recordList) {
            double totalDistance = 0;
            List<CentroidDistance> centroidDistances = r.getCentroidDistances();
            for (CentroidDistance distance : centroidDistances) {
                totalDistance += 1 / distance.getDistance();
            }
            for (CentroidDistance centroidDistance : centroidDistances) {
                Centroid c = newCentroidList.get(centroidDistance.getCentroidIndex());
                Record centroidRecord = c.getRandomRecord();
                double probability = (1 / centroidDistance.getDistance()) / totalDistance;

                int[] attrbs = centroidRecord.getAttributes();
                for (int j = 0; j < totalAttributes; j++) {
                    int d = r.getAttributes()[j];
                    // System.out.println(probability);
                    attrbs[j] = (int) (attrbs[j] + d * probability);
                }
                sum[centroidDistance.getCentroidIndex()] += probability;
            }
        }

        for (int i = 0; i < centroidList.size(); i++) {
            Centroid centroid = newCentroidList.get(i);
            Centroid oldCentroid = centroidList.get(i);
            if (oldCentroid.getAssignedRecords().size() != 0) {
                Record r = centroid.getRandomRecord();
                int[] attrbs = r.getAttributes();
                for (int j = 0; j < totalAttributes; j++) {
                    int d = r.getAttributes()[j];
                    attrbs[j] = (int) (d / sum[i]);
                }
            }
        }

        return newCentroidList;
    }

    private class VFoldRecords {
        List<Record> train = new ArrayList<Record>();
        List<Record> test = new ArrayList<Record>();
    }

    public double vFoldCrossValidation(List<Record> recordList, Set<Integer> attributeIndexes) {
        double ppv = 0;
        for (int i = 0; i < vfoldBlocks; i++) {
            VFoldRecords vFoldRecords = vFoldRecords(recordList, i, vfoldBlocks);
            KMeansResult result = calculate(vFoldRecords.train, attributeIndexes, k, l);

            List<Centroid> finalCentroids = result.getFinalCentroids();
            for (Centroid centroid : finalCentroids) {
                centroid.getAssignedRecords().clear();
            }
            assignRecordsToCentroids(finalCentroids, vFoldRecords.test, attributeIndexes);
            ppv += calculatePPV(finalCentroids, vFoldRecords.test);
        }
        return ppv / vfoldBlocks;
    }

    private VFoldRecords vFoldRecords(List<Record> recordList, int index, int folds) {
        VFoldRecords vFoldRecords = new VFoldRecords();
        int foldSize = recordList.size() / folds;
        int start = index * foldSize;
        int end = start + foldSize > recordList.size() ? recordList.size() - 1 : start + foldSize - 1;

        for (int i = 0; i < recordList.size(); i++) {
            if (i >= start && i <= end) {
                vFoldRecords.test.add(recordList.get(i));
            } else {
                vFoldRecords.train.add(recordList.get(i));
            }
        }
        return vFoldRecords;
    }

}