com.grantingersoll.intell.clustering.KMeansClusteringEngine.java Source code

Introduction

Here is the source code for com.grantingersoll.intell.clustering.KMeansClusteringEngine.java

Source

package com.grantingersoll.intell.clustering;

import org.apache.commons.lang.StringUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.PathFilter;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Writable;
import org.apache.mahout.clustering.Cluster;
import org.apache.mahout.clustering.WeightedVectorWritable;
import org.apache.mahout.clustering.kmeans.KMeansDriver;
import org.apache.mahout.clustering.kmeans.RandomSeedGenerator;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.distance.CosineDistanceMeasure;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.math.NamedVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.utils.vectors.TermInfo;
import org.apache.mahout.utils.vectors.io.JWriterTermInfoWriter;
import org.apache.mahout.utils.vectors.io.SequenceFileVectorWriter;
import org.apache.mahout.utils.vectors.io.VectorWriter;
import org.apache.mahout.utils.vectors.lucene.CachedTermInfo;
import org.apache.mahout.utils.vectors.lucene.LuceneIterable;
import org.apache.mahout.utils.vectors.lucene.TFDFMapper;
import org.apache.mahout.vectorizer.TFIDF;
import org.apache.mahout.vectorizer.Weight;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.core.SolrCore;
import org.apache.solr.core.SolrEventListener;
import org.apache.solr.handler.clustering.DocumentClusteringEngine;
import org.apache.solr.schema.SchemaField;
import org.apache.solr.search.DocSet;
import org.apache.solr.search.SolrIndexReader;
import org.apache.solr.search.SolrIndexSearcher;
import org.apache.solr.util.RefCounted;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.Date;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.TreeMap;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/**
 * Clusters the whole Lucene index using Mahout's k-means implementation.  The {@link org.apache.solr.core.SolrEventListener}
 * side of this implementation is responsible for getting the index clustered whenever a new searcher is opened, while the
 * {@link org.apache.solr.handler.clustering.DocumentClusteringEngine} side is responsible for serving up the resulting clusters.
 * <p/>
 * Eventually, this should evolve to send the clustering job off to a Hadoop cluster, but for now everything is done locally.
 */
public class KMeansClusteringEngine extends DocumentClusteringEngine
        implements SolrEventListener, KMeansClusteringParams {

    private static final Logger log = LoggerFactory.getLogger(KMeansClusteringEngine.class);

    protected File clusterBaseDir;
    protected String inputField;
    protected SolrCore core;

    private DistanceMeasure measure;
    private double convergence = 0.001;
    private int maxIters = 20;
    private int k = 10;
    //private Object swapContext = new Object();
    private ExecutorService execService;
    protected Future<ClusterJob> theFuture;
    private ClusterJob lastSuccessful;
    private boolean cacheClusters = true;//if true, build the cluster data structures as soon as the job finishes; otherwise lazy-load them on request
    private boolean cachePoints = true;//same idea for the cluster-id -> points mapping

    public KMeansClusteringEngine() {
        execService = Executors.newSingleThreadExecutor();
        measure = new CosineDistanceMeasure();
    }

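    /**
     * Handles a clustering request: first checks whether a background clustering job has finished, then either
     * reports on the most recent successful job (cluster list, per-cluster points) or, if there are no clusters
     * yet, kicks off a new build when asked.
     */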
    @Override
    public NamedList cluster(SolrParams params) {
        NamedList result = new NamedList();
        //check to see if we have new results
        try {
            if (theFuture != null) {
                //see if we have new results, but don't wait too long for them
                ClusterJob job = theFuture.get(1, TimeUnit.MILLISECONDS);
                if (lastSuccessful != null) {
                    //clean up the old ones
                    //TODO: clean up the old dirs before switching lastSuccessful
                }
                lastSuccessful = job;
                theFuture = null;
            }

        } catch (InterruptedException e) {
            log.error("Exception", e);
        } catch (ExecutionException e) {
            log.error("Exception", e);
        } catch (TimeoutException e) {
            log.error("Exception", e);
        }
        if (lastSuccessful != null) {//we have clusters
            //do we need the points?
            boolean includePoints = params.getBool(INCLUDE_POINTS, false);
            int clusterId = params.getInt(LIST_POINTS, Integer.MIN_VALUE);
            Map<Integer, List<String>> toPoints = lastSuccessful.clusterIdToPoints;
            String docId = params.get(IN_CLUSTER);
            if ((includePoints || clusterId != Integer.MIN_VALUE || docId != null) && toPoints == null) {
                //load the points
                try {
                    toPoints = readPoints(new Path(lastSuccessful.jobDir + File.separator + "points"),
                            lastSuccessful.conf);
                } catch (IOException e) {
                    throw new SolrException(SolrException.ErrorCode.SERVER_ERROR,
                            "Unable to load points: " + lastSuccessful, e);
                }
            }
            if (params.getBool(LIST_CLUSTERS, false)) {
                NamedList nl = new NamedList();
                result.add("all", nl);

                Map<Integer, Cluster> clusterMap = lastSuccessful.clusters;
                if (clusterMap == null) {
                    //we aren't caching, so load 'em up
                    try {
                        clusterMap = loadClusters(lastSuccessful);
                    } catch (Exception e) {
                        throw new SolrException(SolrException.ErrorCode.SERVER_ERROR,
                                "unable to load the clusters from " + lastSuccessful, e);
                    }
                }

                for (Cluster cluster : clusterMap.values()) {
                    NamedList clusterNL = new NamedList();
                    nl.add(String.valueOf(cluster.getId()), clusterNL);
                    clusterNL.add("numPoints", cluster.getNumPoints());
                    //TODO: better format?
                    clusterNL.add("center", cluster.getCenter().asFormatString());
                    if (cluster.getRadius() != null) {
                        clusterNL.add("radius", cluster.getRadius().asFormatString());
                    }
                    if (includePoints) {
                        List<String> points = toPoints.get(cluster.getId());
                        clusterNL.add("points", points);
                    }
                }
            }

            if (docId != null) {

            }
            //TODO: support sending in multiple ids

            if (clusterId != Integer.MIN_VALUE) {
                //use the map we loaded above; the cached clusterIdToPoints may be null when points aren't cached
                List<String> points = toPoints.get(clusterId);
                if (points != null) {
                    result.add(String.valueOf(clusterId), points);
                }
            }
        } else if (params.getBool(BUILD, false)) {
            RefCounted<SolrIndexSearcher> refCnt = core.getSearcher();
            try {
                int theK = params.getInt(K, 10);
                cluster(refCnt.get(), theK);
            } finally {
                refCnt.decref();//always release the searcher reference, even if clustering fails
            }
        }
        return result;
    }

    @Override
    public NamedList cluster(DocSet docSet, SolrParams solrParams) {
        NamedList result = null;
        //TODO: Schedule these docs for the future and return a key by which an application can pick them up later
        return result;
    }

    @Override
    public String init(NamedList config, SolrCore core) {
        String result = super.init(config, core);
        SolrParams params = SolrParams.toSolrParams(config);
        this.core = core;
        String dirStr = params.get("dir");
        clusterBaseDir = new File(dirStr);
        if (clusterBaseDir.isAbsolute() == false) {
            clusterBaseDir = new File(core.getDataDir(), dirStr);
        }
        clusterBaseDir.mkdirs();
        inputField = params.get("inputField");
        String distMeas = params.get("distanceMeasure");
        Class distClass = core.getResourceLoader().findClass(distMeas);

        try {
            measure = (DistanceMeasure) distClass.newInstance();
        } catch (InstantiationException e) {
            throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, "Unable to load measure class", e);
        } catch (IllegalAccessException e) {
            throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, "Unable to load measure class", e);
        }
        convergence = params.getDouble("convergence", 0.001);
        maxIters = params.getInt("maxIterations", 20);
        cacheClusters = params.getBool("cacheClusters", true);
        cachePoints = params.getBool("cachePoints", true);
        this.k = params.getInt("k");
        //See if we have clusters already
        File nowFile = new File(clusterBaseDir, "lastJob");
        if (nowFile.exists()) {
            lastSuccessful = readJobDetails(nowFile);

        }
        return result;
    }

    //Event Listener

    public void postCommit() {
        //nothing to do here, b/c we need the new searcher
    }

    public void newSearcher(SolrIndexSearcher newSearcher, SolrIndexSearcher currentSearcher) {
        cluster(newSearcher, k);
    }

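    /**
     * Dumps the vectors for the configured input field (plus the term dictionary) into a new timestamped job
     * directory and then submits the actual k-means run to the background executor.
     */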
    private void cluster(SolrIndexSearcher searcher, int k) {
        log.info("Clustering");
        //go and do the clustering.  First, we need to export the fields
        SchemaField keyField = searcher.getSchema().getUniqueKeyField();
        //TODO: should we prevent overlaps here if there are too many commits?  Clustering isn't something that has to be fresh all the time
        // and we likely can't sustain that anyway.
        if (keyField != null) {//we must have a key field
            //do this part synchronously here, and then spawn off a thread to do the clustering, otherwise it will take too long
            String idName = keyField.getName();
            Weight weight = new TFIDF();
            SolrIndexReader reader = searcher.getReader();
            try {
                TermInfo termInfo = new CachedTermInfo(reader, "content", 1, 100);
                LuceneIterable li = new LuceneIterable(reader, idName, inputField,
                        new TFDFMapper(reader, weight, termInfo));
                Date now = new Date();
                String jobDir = clusterBaseDir.getAbsolutePath() + File.separator + "clusters-" + now.getTime();
                new File(jobDir).mkdirs();//make sure the job directory exists before writing the vectors and dictionary into it
                log.info("Dumping {} to {}", inputField, jobDir);
                File outFile = new File(jobDir, "index-" + inputField + ".vec");
                VectorWriter vectorWriter = getSeqFileWriter(outFile.getAbsolutePath());
                long numDocs = vectorWriter.write(li, Integer.MAX_VALUE);
                vectorWriter.close();
                log.info("Wrote: {} vectors", numDocs);
                File dictOutFile = new File(jobDir, "dict-" + inputField + ".txt");
                log.info("Dictionary Output file: {}", dictOutFile);
                BufferedWriter writer = new BufferedWriter(
                        new OutputStreamWriter(new FileOutputStream(dictOutFile), Charset.forName("UTF8")));
                JWriterTermInfoWriter tiWriter = new JWriterTermInfoWriter(writer, "\t", inputField);
                tiWriter.write(termInfo);
                tiWriter.close();
                writer.close();
                //OK, the dictionary is dumped, now we can cluster, do this via a thread in the background.
                //when it's done, we can switch to it
                ClusterJob clusterJob = new ClusterJob(k, jobDir, new Path(outFile.getAbsolutePath()),
                        new Path(jobDir + File.separator + "clusters"),
                        new Path(jobDir + File.separator + "output"), new Path(dictOutFile.getAbsolutePath()));

                writeJobDetails(clusterJob);
                theFuture = execService.submit(new ClusterCallable(clusterJob));
            } catch (IOException e) {
                log.error("Exception", e);
            }
        }

    }

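    //Persist the details of the submitted job so they can be reloaded via readJobDetails() after a restart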
    private void writeJobDetails(ClusterJob clusterJob) throws IOException {
        File nowFile = new File(clusterBaseDir, "lastJob");
        Properties props = new Properties();
        //Properties.store() requires String values, so convert the Paths and k explicitly
        props.setProperty("jobDir", clusterJob.jobDir);
        props.setProperty("k", String.valueOf(clusterJob.k));
        props.setProperty("clustersIn", clusterJob.clustersIn.toString());
        props.setProperty("output", clusterJob.output.toString());
        props.setProperty("input", clusterJob.input.toString());
        props.setProperty("dictionary", clusterJob.dictionary.toString());
        FileWriter fWriter = new FileWriter(nowFile);
        try {
            props.store(fWriter, "lastJob");
        } finally {
            fWriter.close();
        }
    }

    private ClusterJob readJobDetails(File jobFile) {
        log.info("Reading job from: {} ", jobFile);
        ClusterJob result = null;
        try {
            FileReader fReader = new FileReader(jobFile);
            Properties props = new Properties();
            try {
                props.load(fReader);
            } finally {
                fReader.close();
            }
            result = new ClusterJob(Integer.parseInt(props.getProperty("k")), props.getProperty("jobDir"),
                    new Path(props.getProperty("input")), new Path(props.getProperty("clustersIn")),
                    new Path(props.getProperty("output")), new Path(props.getProperty("dictionary")));
        } catch (FileNotFoundException e) {
            throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, "Unable to load: " + jobFile, e);
        } catch (IOException e) {
            throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, "Unable to read: " + jobFile, e);
        }
        log.info("Read job: {}", result);
        return result;
    }

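    //Simple holder for the paths, configuration and (optionally cached) results of a single clustering run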
    private class ClusterJob {
        String jobDir;
        Path input, clustersIn, output, dictionary;
        Configuration conf;
        Map<Integer, Cluster> clusters;
        Map<Integer, List<String>> clusterIdToPoints;
        int k;

        private ClusterJob(int k, String jobDir, Path input, Path clustersIn, Path output, Path dictionary) {
            this.k = k;
            this.jobDir = jobDir;
            this.input = input;
            this.clustersIn = clustersIn;
            this.output = output;
            this.dictionary = dictionary;
            this.conf = new Configuration();
            //clusters and clusterIdToPoints stay null until they are either cached by the background job or loaded lazily on request
        }

        @Override
        public String toString() {
            return "ClusterJob{" + "jobDir='" + jobDir + '\'' + ", input=" + input + ", clustersIn=" + clustersIn
                    + ", output=" + output + ", dictionary=" + dictionary + ", k=" + k + '}';
        }
    }

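    //Runs the random seed generation and the Mahout KMeansDriver on the background executor and returns the finished job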
    private class ClusterCallable implements Callable<ClusterJob> {
        ClusterJob job;

        private ClusterCallable(ClusterJob job) {
            this.job = job;
            job.conf = new Configuration();
        }

        public ClusterJob call() throws Exception {
            log.info("Randomly seeding {} vectors from the input: {}", job.k, job.input);
            job.clustersIn = RandomSeedGenerator.buildRandom(job.input, job.clustersIn, job.k, measure);
            log.info("KMeansDriver.run: " + job);
            KMeansDriver.run(job.conf, job.input, job.clustersIn, job.output, measure, convergence, maxIters, true,
                    true);
            //job is done; if we're caching, build the data structures now in the background rather than waiting until requested
            if (cacheClusters) {
                job.clusters = loadClusters(job);
            }
            if (cachePoints) {
                job.clusterIdToPoints = readPoints(new Path(job.jobDir + File.separator + "points"), job.conf);
            }
            log.info("Finished KMeansDriver.run: " + job);
            return job;
        }
    }

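    //Read the cluster centers back out of the job's output sequence files, keyed by cluster id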
    private static Map<Integer, Cluster> loadClusters(ClusterJob job) throws Exception {
        Map<Integer, Cluster> result = new HashMap<Integer, Cluster>();
        FileSystem fs = job.output.getFileSystem(job.conf);
        for (FileStatus seqFile : fs.globStatus(new Path(job.output, "part-*"))) {
            Path path = seqFile.getPath();
            SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, job.conf);
            try {
                Writable key = reader.getKeyClass().asSubclass(Writable.class).newInstance();
                Writable value = reader.getValueClass().asSubclass(Writable.class).newInstance();
                while (reader.next(key, value)) {
                    Cluster cluster = (Cluster) value;
                    result.put(cluster.getId(), cluster);
                    //next() reuses the value instance, so create a fresh one before reading the next cluster
                    value = reader.getValueClass().asSubclass(Writable.class).newInstance();
                }
            } finally {
                reader.close();
            }
        }
        return result;
    }

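    //Read the clustered-points output, mapping each cluster id to the names (document ids) of the vectors assigned to it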
    private static Map<Integer, List<String>> readPoints(Path pointsPathDir, Configuration conf)
            throws IOException {
        Map<Integer, List<String>> result = new TreeMap<Integer, List<String>>();

        FileSystem fs = pointsPathDir.getFileSystem(conf);
        FileStatus[] children = fs.listStatus(pointsPathDir, new PathFilter() {
            public boolean accept(Path path) {
                String name = path.getName();
                return !(name.endsWith(".crc") || name.startsWith("_"));
            }
        });

        for (FileStatus file : children) {
            Path path = file.getPath();
            SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, conf);
            try {
                IntWritable key = reader.getKeyClass().asSubclass(IntWritable.class).newInstance();
                WeightedVectorWritable value = reader.getValueClass().asSubclass(WeightedVectorWritable.class)
                        .newInstance();
                while (reader.next(key, value)) {
                    //key is the clusterId, value is the weighted vector for a single point
                    List<String> pointList = result.get(key.get());
                    if (pointList == null) {
                        pointList = new ArrayList<String>();
                        result.put(key.get(), pointList);
                    }
                    //We know we are dealing with named vectors, b/c we generated them from the id field
                    String name = ((NamedVector) value.getVector()).getName();
                    pointList.add(name);
                }
            } catch (InstantiationException e) {
                log.error("Exception", e);
            } catch (IllegalAccessException e) {
                log.error("Exception", e);
            } finally {
                reader.close();
            }
        }

        return result;
    }

    static class TermIndexWeight {
        private int index = -1;

        private final double weight;

        TermIndexWeight(int index, double weight) {
            this.index = index;
            this.weight = weight;
        }
    }

    //TODO: remove once MAHOUT-536 is committed

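    //Maps the highest-weighted dimensions of a vector back to their dictionary terms and formats them for display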
    public static String getTopFeatures(Vector vector, String[] dictionary, int numTerms) {

        List<TermIndexWeight> vectorTerms = new ArrayList<TermIndexWeight>();

        Iterator<Vector.Element> iter = vector.iterateNonZero();
        while (iter.hasNext()) {
            Vector.Element elt = iter.next();
            vectorTerms.add(new TermIndexWeight(elt.index(), elt.get()));
        }

        // Sort results in reverse order (ie weight in descending order)
        Collections.sort(vectorTerms, new Comparator<TermIndexWeight>() {

            public int compare(TermIndexWeight one, TermIndexWeight two) {
                return Double.compare(two.weight, one.weight);
            }
        });

        Collection<Pair<String, Double>> topTerms = new LinkedList<Pair<String, Double>>();

        for (int i = 0; (i < vectorTerms.size()) && (i < numTerms); i++) {
            int index = vectorTerms.get(i).index;
            String dictTerm = dictionary[index];
            if (dictTerm == null) {
                log.error("Dictionary entry missing for {}", index);
                continue;
            }
            topTerms.add(new Pair<String, Double>(dictTerm, vectorTerms.get(i).weight));
        }

        StringBuilder sb = new StringBuilder(100);

        for (Pair<String, Double> item : topTerms) {
            String term = item.getFirst();
            sb.append("\n\t\t");
            sb.append(StringUtils.rightPad(term, 40));
            sb.append("=>");
            sb.append(StringUtils.leftPad(item.getSecond().toString(), 20));
        }
        return sb.toString();
    }

    public void init(NamedList args) {
        //Defer all work to the clustering engine init
    }

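    //Create a Hadoop SequenceFile-backed writer for dumping the index vectors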
    private static VectorWriter getSeqFileWriter(String outFile) throws IOException {
        Path path = new Path(outFile);
        Configuration conf = new Configuration();
        FileSystem fs = FileSystem.get(conf);
        // TODO: Make this parameter driven

        SequenceFile.Writer seqWriter = SequenceFile.createWriter(fs, conf, path, LongWritable.class,
                VectorWritable.class);

        return new SequenceFileVectorWriter(seqWriter);
    }

}
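
The listing above references request-parameter constants (BUILD, K, LIST_CLUSTERS, LIST_POINTS, INCLUDE_POINTS and IN_CLUSTER) from the KMeansClusteringParams interface, which is not included on this page. The sketch below is only a rough guess at what that interface might look like: the constant names come from the source above, but the parameter value strings are assumptions, not the real values.

package com.grantingersoll.intell.clustering;

//Hypothetical sketch only: the real KMeansClusteringParams is not shown in this listing,
//and the string values below are assumptions.
public interface KMeansClusteringParams {
    String BUILD = "build";                  //assumed: trigger a new clustering run
    String K = "k";                          //assumed: how many clusters to build
    String LIST_CLUSTERS = "listClusters";   //assumed: list all clusters from the last successful job
    String LIST_POINTS = "listPoints";       //assumed: cluster id whose points should be listed
    String INCLUDE_POINTS = "includePoints"; //assumed: include point ids with each cluster
    String IN_CLUSTER = "inCluster";         //assumed: document id to look up a cluster for
}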