// MPCKMeans.java -- source code listing (metric pairwise-constrained K-Means clusterer)

/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    MPCKMeans.java
 *    Copyright (C) 2003 Sugato Basu and Misha Bilenko
 *
 */

import java.io.BufferedReader;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Random;
import java.util.StringTokenizer;
import java.util.TreeMap;

import weka.clusterers.Clusterer;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.OptionHandler;
import weka.core.SelectedTag;
import weka.core.SparseInstance;
import weka.core.Tag;
import weka.core.UnsupportedAttributeTypeException;
import weka.core.Utils;

/**
 * Pairwise constrained k means clustering class.
 *
 * Valid options are:<p>
 *
 * -N <number of clusters> <br>
 * Specify the number of clusters to generate. <p>
 *
 * -R <random seed> <br>
 * Specify random number seed <p>
 *
 * -M <metric-class> <br>
 * Specifies the name of the distance metric class that should be used
 * 
 * @author Sugato Basu(sugato@cs.utexas.edu) and Misha Bilenko (mbilenko@cs.utexas.edu)
 * @see Clusterer
 * @see OptionHandler
 */
public class MPCKMeans extends Clusterer implements OptionHandler, SemiSupClusterer {
    // BUGFIX: the class previously extended the nonexistent lowercase type
    // "clusterers"; the intended base class is weka.clusterers.Clusterer
    // (imported above, referenced by the class javadoc and getThisClusterer()).

    /** Name of clusterer */
    String m_name = "MPCKMeans";

    /** holds the instances in the clusters */
    protected ArrayList m_Clusters = null;

    /** holds the instance indices in the clusters */
    protected HashSet[] m_IndexClusters = null;

    /** holds the ([instance pair] -> [type of constraint]) mapping,
        where the hashed value stores the type of link but the instance
        pair does not hold the type of constraint - it holds (instanceIdx1,
        instanceIdx2, DONT_CARE_LINK). This is done to make lookup easier
        in future 
    */
    protected HashMap m_ConstraintsHash = null;

    /** @return the (instance pair -> constraint type) mapping, or null if not built */
    public HashMap getConstraintsHash() {
        return m_ConstraintsHash;
    }

    /** stores the ([instanceIdx] -> [ArrayList of constraints])
        mapping, where the arraylist contains the constraints in which
        instanceIdx is involved. Note that the instance pairs stored in
        the Arraylist have the actual link type.  
    */
    protected HashMap m_instanceConstraintHash = null;

    /** @return the (instance index -> ArrayList of constraints) mapping, or null if not built */
    public HashMap getInstanceConstraintsHash() {
        return m_instanceConstraintHash;
    }

    /** @param instanceConstraintHash the (instance index -> ArrayList of constraints) mapping to use */
    public void setInstanceConstraintsHash(HashMap instanceConstraintHash) {
        m_instanceConstraintHash = instanceConstraintHash;
    }

    /** holds the (Integer) indices of the points involved in constraints */
    protected HashSet m_SeedHash = null;

    /** @return the set of instance indices involved in constraints, or null if not built */
    public HashSet getSeedHash() {
        return m_SeedHash;
    }

    /** weight to be given to each constraint */
    protected double m_CLweight = 1;

    /** weight to be given to each constraint */
    protected double m_MLweight = 1;

    /** should constraints from transitive closure be added? */
    protected boolean m_useTransitiveConstraints = true;

    /** is it an offline metric (BarHillelMetric or XingMetric)? */
    protected boolean m_isOfflineMetric;

    /** @return true if the current metric is an offline-trained metric */
    public boolean getIsOfflineMetric() {
        return m_isOfflineMetric;
    }

    /** the maximum distance between cannot-link constraints */
    protected double m_MaxCannotLinkDistance = 0;

    /** the min similarity between cannot-link constraints */
    protected double m_MaxCannotLinkSimilarity = 0;

    /** per-cluster maximum cannot-link penalties, used as the offset when scoring cannot-link violations */
    protected double m_maxCLPenalties[] = null;
    public Instance m_maxCLPoints[][] = null;
    public Instance m_maxCLDiffInstances[] = null;

    /** verbose? */
    protected boolean m_verbose = false;

    /** distance Metric */
    protected LearnableMetric m_metric = new WeightedEuclidean();
    protected MPCKMeansMetricLearner m_metricLearner = new WEuclideanLearner();

    /** Individual metrics for each cluster can be used */
    protected boolean m_useMultipleMetrics = false;
    protected LearnableMetric[] m_metrics = null;
    protected MPCKMeansMetricLearner[] m_metricLearners = null;

    /** Relative importance of the log-term for the weights in the objective function */
    protected double m_logTermWeight = 0.01;

    /** Regularization for weights */
    protected boolean m_regularize = false;
    protected double m_regularizerTermWeight = 0.001;

    /** We will hash log terms to avoid recomputing every time TODO:  implement for Euclidean*/
    protected double[] m_logTerms = null;

    /** has the metric has been constructed?  a fix for multiple buildClusterer's */
    protected boolean m_metricBuilt = false;

    /** indicates whether instances are sparse */
    protected boolean m_isSparseInstance = false;

    /** Is the objective function increasing or decreasing?  Depends on type
     * of metric used:  for similarity-based metric, increasing, for distance-based - decreasing */
    protected boolean m_objFunDecreasing = true;

    /** Seedable or not (true by default) */
    protected boolean m_Seedable = true;

    /** Possible metric training */
    public static final int TRAINING_NONE = 1;
    public static final int TRAINING_EXTERNAL = 2;
    public static final int TRAINING_INTERNAL = 4;
    public static final Tag[] TAGS_TRAINING = { new Tag(TRAINING_NONE, "None"),
            new Tag(TRAINING_EXTERNAL, "External"), new Tag(TRAINING_INTERNAL, "Internal") };

    protected int m_Trainable = TRAINING_INTERNAL;

    /** keep track of the number of iterations completed before convergence
     */
    protected int m_Iterations = 0;

    /** number of constraint violations
     */
    protected int m_numViolations = 0;

    /** keep track of the number of iterations when no points were moved */
    protected int m_numBlankIterations = 0;

    /** the maximum number of iterations */
    protected int m_maxIterations = Integer.MAX_VALUE;

    /** the maximum number of iterations with no points moved */
    protected int m_maxBlankIterations = 20;

    /** min difference of objective function values for convergence*/
    protected double m_ObjFunConvergenceDifference = 1e-5;

    /** value of current objective function */
    protected double m_Objective = Double.MAX_VALUE;

    /** value of last objective function */
    protected double m_OldObjective;

    /** Variables to track components of the objective function */
    protected double m_objVariance;
    protected double m_objCannotLinks;
    protected double m_objMustLinks;
    protected double m_objNormalizer;
    protected double m_objRegularizer;
    /** Variable to track the contribution of the currently considered point */
    protected double m_objVarianceCurrPoint;
    protected double m_objCannotLinksCurrPoint;
    protected double m_objMustLinksCurrPoint;
    protected double m_objNormalizerCurrPoint;

    protected double m_objVarianceCurrPointBest;
    protected double m_objCannotLinksCurrPointBest;
    protected double m_objMustLinksCurrPointBest;
    protected double m_objNormalizerCurrPointBest;

    /** @return the current value of the clustering objective function */
    public double objectiveFunction() {
        return m_Objective;
    }

    /**
     * training instances with labels
     */
    protected Instances m_TotalTrainWithLabels;

    /** @return the training instances that still carry class labels */
    public Instances getTotalTrainWithLabels() {
        return m_TotalTrainWithLabels;
    }

    /** @param inst training instances (with class labels) to store */
    public void setTotalTrainWithLabels(Instances inst) {
        m_TotalTrainWithLabels = inst;
    }

    /**
     * training instances
     */
    protected Instances m_Instances;

    /** A hash where the instance checksums are hashed */
    protected HashMap m_checksumHash = null;
    protected double[] m_checksumCoeffs = null;

    /** test data -- required to make sure that test points are not
        selected during active learning */
    protected int m_StartingIndexOfTest = -1;

    /**
     * number of clusters to generate, default is -1 to get it from labeled data
     */
    protected int m_NumClusters = -1;

    /**
     * holds the cluster centroids
     */
    protected Instances m_ClusterCentroids;

    /** @return the current cluster centroids */
    public Instances getClusterCentroids() {
        return m_ClusterCentroids;
    }

    /** @param centroids the cluster centroids to use */
    public void setClusterCentroids(Instances centroids) {
        m_ClusterCentroids = centroids;
    }

    /**
     * temporary variable holding cluster assignments while iterating
     */
    protected int[] m_ClusterAssignments;

    /** @return the per-instance cluster assignments (index i holds instance i's cluster) */
    public int[] getClusterAssignments() {
        return m_ClusterAssignments;
    }

    /** @param clusterAssignments the per-instance cluster assignments to use */
    public void setClusterAssignments(int[] clusterAssignments) {
        m_ClusterAssignments = clusterAssignments;
    }

    /** optional file path where cluster assignments are written */
    protected String m_ClusterAssignmentsOutputFile;

    /** @return the output file path for cluster assignments, or null if unset */
    public String getClusterAssignmentsOutputFile() {
        return m_ClusterAssignmentsOutputFile;
    }

    /** @param file the output file path for cluster assignments */
    public void setClusterAssignmentsOutputFile(String file) {
        m_ClusterAssignmentsOutputFile = file;
    }

    /** optional file path for constraint-incoherence output */
    protected String m_ConstraintIncoherenceFile;

    /** @return the constraint-incoherence file path, or null if unset */
    public String getConstraintIncoherenceFile() {
        return m_ConstraintIncoherenceFile;
    }

    /** @param file the constraint-incoherence file path */
    public void setConstraintIncoherenceFile(String file) {
        m_ConstraintIncoherenceFile = file;
    }

    /**
     * holds the random Seed, useful for randomPerturbInit
     */
    protected int m_RandomSeed = 42;

    /**
     * holds the random number generator used in various parts of the code
     */
    protected Random m_RandomNumberGenerator = null;

    /** Define possible assignment strategies */
    protected MPCKMeansAssigner m_Assigner = new SimpleAssigner(this);

    /** Define possible initialization strategies */
    //  protected MPCKMeansInitializer m_Initializer = new RandomPerturbInitializer(this);
    protected MPCKMeansInitializer m_Initializer = new WeightedFFNeighborhoodInit(this);

    /** @return the random number generator (seeded from m_RandomSeed in buildClusterer) */
    public Random getRandomNumberGenerator() {
        return m_RandomNumberGenerator;
    }

    /** Default constructor; keeps the default WeightedEuclidean metric. */
    public MPCKMeans() {
    }

    /** Constructor that sets the distance metric.
     * @param metric the learnable metric to use; whether the objective
     *        function decreases is taken from metric.isDistanceBased()
     */
    public MPCKMeans(LearnableMetric metric) {
        m_metric = metric;
        m_objFunDecreasing = metric.isDistanceBased();
    }

    /**
     * We always want to implement SemiSupClusterer from a class extending
     * Clusterer, so this exposes the underlying parent class.
     * @return this object viewed as its parent Clusterer class
     */
    public Clusterer getThisClusterer() {
        return this;
    }

    /**
     * Clusters the given instances into the requested number of clusters.
     *
     * @param data instances to be clustered
     * @param numClusters number of clusters to create
     * @exception Exception if something goes wrong.
     */
    public void buildClusterer(Instances data, int numClusters) throws Exception {
        m_NumClusters = numClusters;
        System.out.println("Creating " + m_NumClusters + " clusters");
        m_Initializer.setNumClusters(m_NumClusters);

        // remember whether we are dealing with sparse instances
        m_isSparseInstance = m_isSparseInstance || (data.instance(0) instanceof SparseInstance);

        buildClusterer(data);
    }

    /**
     * Seeded clustering entry point required by the SemiSupClusterer
     * interface; intentionally unsupported in MPCKMeans.
     *
     * @param labeledData set of labeled instances to use as seeds
     * @param unlabeledData set of unlabeled instances
     * @param classIndex attribute index in labeledData which holds class info
     * @param numClusters number of clusters to create
     * @param startingIndexOfTest from where test data starts in unlabeledData
     * @exception Exception always -- this variant is not implemented
     */
    public void buildClusterer(Instances labeledData, Instances unlabeledData, int classIndex, int numClusters,
            int startingIndexOfTest) throws Exception {
        throw new Exception("Not implemented for MPCKMeans, only here for compatibility to SemiSupClusterer interface");
    }

    /**
     * Clusters unlabeledData and labeledData (with labels removed),
     * using constraints in labeledPairs to initialize.
     *
     * Populates m_SeedHash (instance indices under constraints),
     * m_ConstraintsHash (canonical pair -> link type) and
     * m_instanceConstraintHash (instance index -> its constraints),
     * builds/normalizes the metric, then delegates to
     * buildClusterer(Instances, int).
     *
     * @param labeledPairs labeled pairs (InstancePair) to be used to initialize; may be null
     * @param unlabeledData unlabeled instances
     * @param labeledData labeled instances
     * @param numClusters number of clusters
     * @param startingIndexOfTest starting index of test set in unlabeled data
     * @exception Exception if a pair is not ordered first &lt; second, or clustering fails */
    public void buildClusterer(ArrayList labeledPairs, Instances unlabeledData, Instances labeledData,
            int numClusters, int startingIndexOfTest) throws Exception {
        m_TotalTrainWithLabels = labeledData;

        if (labeledPairs != null) {
            // capacity chosen so the default load factor (0.75) avoids rehashing
            m_SeedHash = new HashSet((int) (unlabeledData.numInstances() / 0.75 + 10));
            m_ConstraintsHash = new HashMap();
            m_instanceConstraintHash = new HashMap();

            for (int i = 0; i < labeledPairs.size(); i++) {
                InstancePair pair = (InstancePair) labeledPairs.get(i);
                Integer firstInt = new Integer(pair.first);
                Integer secondInt = new Integer(pair.second);

                // for first point 
                if (!m_SeedHash.contains(firstInt)) { // add instances with constraints to seedHash
                    if (m_verbose) {
                        System.out.println("Adding " + firstInt + " to seedHash");
                    }
                    m_SeedHash.add(firstInt);
                }

                // for second point 
                if (!m_SeedHash.contains(secondInt)) {
                    m_SeedHash.add(secondInt);
                    if (m_verbose) {
                        System.out.println("Adding " + secondInt + " to seedHash");
                    }
                }
                // pairs must arrive in canonical order (first < second)
                if (pair.first >= pair.second) {
                    throw new Exception("Ordering reversed - something wrong!!");
                } else {
                    // the hash key uses DONT_CARE_LINK so lookups work regardless
                    // of constraint type; the actual type is kept in the value
                    InstancePair newPair = null;
                    newPair = new InstancePair(pair.first, pair.second, InstancePair.DONT_CARE_LINK);
                    m_ConstraintsHash.put(newPair, new Integer(pair.linkType)); // WLOG first < second
                    if (m_verbose) {
                        System.out.println(
                                "Adding constraint (" + pair.first + "," + pair.second + "), " + pair.linkType);
                    }

                    // hash the constraints for the instances involved
                    Object constraintList1 = m_instanceConstraintHash.get(firstInt);
                    if (constraintList1 == null) {
                        ArrayList constraintList = new ArrayList();
                        constraintList.add(pair);
                        m_instanceConstraintHash.put(firstInt, constraintList);
                    } else {
                        ((ArrayList) constraintList1).add(pair);
                    }
                    Object constraintList2 = m_instanceConstraintHash.get(secondInt);
                    if (constraintList2 == null) {
                        ArrayList constraintList = new ArrayList();
                        constraintList.add(pair);
                        m_instanceConstraintHash.put(secondInt, constraintList);
                    } else {
                        ((ArrayList) constraintList2).add(pair);
                    }
                }
            }
        }

        m_StartingIndexOfTest = startingIndexOfTest;
        if (m_verbose) {
            System.out.println("Starting index of test: " + m_StartingIndexOfTest);
        }

        // learn metric using labeled data,
        // then cluster both the labeled and unlabeled data
        System.out.println("Initializing metric: " + m_metric);
        m_metric.buildMetric(unlabeledData);
        m_metricBuilt = true;
        m_metricLearner.setMetric(m_metric);
        m_metricLearner.setClusterer(this);

        // normalize all data for SPKMeans
        if (m_metric.doesNormalizeData()) {
            for (int i = 0; i < unlabeledData.numInstances(); i++) {
                m_metric.normalizeInstanceWeighted(unlabeledData.instance(i));
            }
        }

        // either create a new metric if multiple metrics,
        // or just point them all to m_metric
        m_metrics = new LearnableMetric[numClusters];
        m_metricLearners = new MPCKMeansMetricLearner[numClusters];
        for (int i = 0; i < m_metrics.length; i++) {
            if (m_useMultipleMetrics) {
                m_metrics[i] = (LearnableMetric) m_metric.clone();
                m_metricLearners[i] = (MPCKMeansMetricLearner) m_metricLearner.clone();
                m_metricLearners[i].setMetric(m_metrics[i]);
                m_metricLearners[i].setClusterer(this);
            } else {
                m_metrics[i] = m_metric;
                m_metricLearners[i] = m_metricLearner;
            }
        }
        buildClusterer(unlabeledData, numClusters);
    }

    /**
     * Generates a clusterer. Instances in data have to be
     * either all sparse or all non-sparse.
     *
     * @param data set of instances serving as training data 
     * @exception Exception if the clusterer has not been 
     * generated successfully
     */
    public void buildClusterer(Instances data) throws Exception {
        System.out.println("ML weight=" + m_MLweight);
        System.out.println("CL weight= " + m_CLweight);
        System.out.println("LOG term weight=" + m_logTermWeight);
        System.out.println("Regularizer weight= " + m_regularizerTermWeight);
        m_RandomNumberGenerator = new Random(m_RandomSeed);

        // offline metrics (e.g. BarHillelMetric, XingMetric) are trained up front
        m_isOfflineMetric = (m_metric instanceof OfflineLearnableMetric);

        // Don't rebuild the metric if it was already trained
        if (!m_metricBuilt) {
            m_metric.buildMetric(data);
            m_metricBuilt = true;
            m_metricLearner.setMetric(m_metric);
            m_metricLearner.setClusterer(this);

            // either create individual (cloned) per-cluster metrics,
            // or point every slot at the single shared metric
            m_metrics = new LearnableMetric[m_NumClusters];
            m_metricLearners = new MPCKMeansMetricLearner[m_NumClusters];
            for (int i = 0; i < m_metrics.length; i++) {
                if (m_useMultipleMetrics) {
                    m_metrics[i] = (LearnableMetric) m_metric.clone();
                    m_metricLearners[i] = (MPCKMeansMetricLearner) m_metricLearner.clone();
                    m_metricLearners[i].setMetric(m_metrics[i]);
                    m_metricLearners[i].setClusterer(this);
                } else {
                    m_metrics[i] = m_metric;
                    m_metricLearners[i] = m_metricLearner;
                }
            }
        }

        setInstances(data);
        m_ClusterCentroids = new Instances(m_Instances, m_NumClusters);
        m_ClusterAssignments = new int[m_Instances.numInstances()];

        // BUGFIX: was '&&', which only rejected data containing BOTH nominal
        // and string attributes; either kind alone is unsupported, as the
        // exception message itself says
        if (m_Instances.checkForNominalAttributes() || m_Instances.checkForStringAttributes()) {
            throw new UnsupportedAttributeTypeException("Cannot handle nominal attributes\n");
        }

        m_ClusterCentroids = m_Initializer.initialize();

        // if all instances are smoothed by the metric, the centroids
        // need to be smoothed too (note that this is independent of
        // centroid smoothing performed by K-Means)
        if (m_metric instanceof InstanceConverter) {
            System.out.println("Converting centroids...");
            Instances convertedCentroids = new Instances(m_ClusterCentroids, m_NumClusters);
            for (int i = 0; i < m_ClusterCentroids.numInstances(); i++) {
                Instance centroid = m_ClusterCentroids.instance(i);
                convertedCentroids.add(((InstanceConverter) m_metric).convertInstance(centroid));
            }

            // replace the centroid contents in place so the Instances object keeps its identity
            m_ClusterCentroids.delete();
            for (int i = 0; i < convertedCentroids.numInstances(); i++) {
                m_ClusterCentroids.add(convertedCentroids.instance(i));
            }
        }

        System.out.println("Done initializing clustering ...");
        getIndexClusters();

        if (m_verbose && m_Seedable) {
            printIndexClusters();
            for (int i = 0; i < m_NumClusters; i++) {
                System.out.println("Centroid " + i + ": " + m_ClusterCentroids.instance(i));
            }
        }

        // Some extra work for smoothing metrics
        if (m_metric instanceof SmoothingMetric && ((SmoothingMetric) m_metric).getUseSmoothing()) {

            SmoothingMetric smoothingMetric = (SmoothingMetric) m_metric;
            Instances smoothedCentroids = new Instances(m_Instances, m_NumClusters);

            for (int i = 0; i < m_ClusterCentroids.numInstances(); i++) {
                Instance smoothedCentroid = smoothingMetric.smoothInstance(m_ClusterCentroids.instance(i));
                smoothedCentroids.add(smoothedCentroid);
            }
            m_ClusterCentroids = smoothedCentroids;

            updateSmoothingMetrics();
        }

        runKMeans();
    }

    /** Refreshes the smoothing parameter (alpha) of the smoothing metric(s) in use. */
    protected void updateSmoothingMetrics() {
        if (!m_useMultipleMetrics) {
            // single shared metric
            ((SmoothingMetric) m_metric).updateAlpha();
            return;
        }
        // one metric per cluster
        for (int clusterIdx = 0; clusterIdx < m_NumClusters; clusterIdx++) {
            ((SmoothingMetric) m_metrics[clusterIdx]).updateAlpha();
        }
    }

    /**
     * Reset all values that have been learned: the metric(s) and
     * all constraint bookkeeping.
     */
    public void resetClusterer() throws Exception {
        m_metric.resetMetric();
        if (m_useMultipleMetrics) {
            for (int idx = 0; idx < m_metrics.length; ++idx) {
                m_metrics[idx].resetMetric();
            }
        }

        // drop all constraint state
        m_SeedHash = null;
        m_ConstraintsHash = null;
        m_instanceConstraintHash = null;
    }

    /** Turn seeding on and off
     * @param seedable should seeding be done?
     */
    public void setSeedable(boolean seedable) {
        m_Seedable = seedable;
    }

    /** Turn metric learning on and off.
     * @param trainable selected training mode; ignored unless drawn from TAGS_TRAINING
     */
    public void setTrainable(SelectedTag trainable) {
        if (trainable.getTags() != TAGS_TRAINING) {
            return; // not one of our tags -- leave the current mode untouched
        }
        if (m_verbose) {
            System.out.println("Trainable: " + trainable.getSelectedTag().getReadable());
        }
        m_Trainable = trainable.getSelectedTag().getID();
    }

    /** Is seeding performed?
     * @return is seeding being done?
     */
    public boolean getSeedable() {
        return m_Seedable;
    }

    /** Is metric learning performed?
     * @return the current training mode wrapped as a SelectedTag from TAGS_TRAINING
     */
    public SelectedTag getTrainable() {
        return new SelectedTag(m_Trainable, TAGS_TRAINING);
    }

    /**
     * We can have clusterers that don't utilize seeding
     * @return whether this clusterer uses seeding
     */
    public boolean seedable() {
        return m_Seedable;
    }

    /** Outputs the current clustering: for each cluster, its size and the
     * class label of every member (when class information is available).
     *
     * @exception Exception if the clusters have not been created yet
     */
    public void printIndexClusters() throws Exception {
        if (m_IndexClusters == null) {
            throw new Exception("Clusters were not created");
        }

        for (int clusterIdx = 0; clusterIdx < m_NumClusters; clusterIdx++) {
            HashSet cluster = m_IndexClusters[clusterIdx];
            if (cluster == null) {
                System.out.println("Cluster " + clusterIdx + " is null");
                continue;
            }

            System.out.println("Cluster " + clusterIdx + " consists of " + cluster.size() + " elements");
            for (Iterator iter = cluster.iterator(); iter.hasNext();) {
                int idx = ((Integer) iter.next()).intValue();
                Instance inst = m_TotalTrainWithLabels.instance(idx);
                // only print labels if the labeled data actually has a class attribute
                if (m_TotalTrainWithLabels.classIndex() >= 0) {
                    System.out
                            .println("\t\t" + idx + ":" + inst.classAttribute().value((int) inst.classValue()));
                }
            }
        }
    }

    /** E-step of the KMeans clustering algorithm -- find best cluster
     * assignments. Returns the number of points moved in this step.
     */
    protected int findBestAssignments() throws Exception {
        // zero out the objective function and its components before reassignment
        m_Objective = 0;
        m_objVariance = 0;
        m_objCannotLinks = 0;
        m_objMustLinks = 0;
        m_objNormalizer = 0;

        // Initialize the regularizer and normalizer hashes
        InitNormalizerRegularizer();

        int moved = m_isOfflineMetric ? assignAllInstancesToClusters() : assignPoints();

        if (m_verbose) {
            System.out.println("  " + moved + " points moved in this E-step");
        }
        return moved;
    }

    /** Initialize m_logTerms and the regularizer component m_objRegularizer. */
    protected void InitNormalizerRegularizer() {
        m_logTerms = new double[m_NumClusters];
        m_objRegularizer = 0;

        if (!m_useMultipleMetrics) {
            // single metric: every cluster shares the log(det) of the one weight matrix
            double sharedLogTerm = m_logTermWeight * m_metric.getNormalizer();
            for (int i = 0; i < m_logTerms.length; i++) {
                m_logTerms[i] = sharedLogTerm;
            }
            if (m_regularize) {
                m_objRegularizer = m_regularizerTermWeight * m_metric.regularizer();
            }
        } else {
            // one term (and regularizer contribution) per cluster metric
            for (int i = 0; i < m_NumClusters; i++) {
                m_logTerms[i] = m_logTermWeight * m_metrics[i].getNormalizer();
                if (m_regularize) {
                    m_objRegularizer += m_regularizerTermWeight * m_metrics[i].regularizer();
                }
            }
        }
    }

    /** Delegates point assignment to the configured assigner and refreshes
     * the objective function from its accumulated components. */
    int assignPoints() throws Exception {
        int moved = m_Assigner.assign();

        // objective = variance + constraint penalties + log-normalizer - regularizer
        m_Objective = m_objVariance + m_objMustLinks + m_objCannotLinks + m_objNormalizer - m_objRegularizer;

        if (m_verbose) {
            System.out.println((float) m_Objective + " - Objective function (incomplete) after assignment");
            System.out.println("\tvar=" + ((float) m_objVariance) + "\tC=" + ((float) m_objCannotLinks) + "\tM="
                    + ((float) m_objMustLinks) + "\tLOG=" + ((float) m_objNormalizer) + "\tREG="
                    + ((float) m_objRegularizer));
        }
        // TODO:  add a m_fast switch and put the following line inside it.
        //    calculateObjectiveFunction();

        return moved;
    }

    /**
     * Assigns the instance at the given index to the cluster with the
     * lowest penalty, taking constraint violations into account.
     *
     * Side effects: adds the winning cluster's per-point components to the
     * running objective terms (m_objVariance, m_objNormalizer,
     * m_objMustLinks, m_objCannotLinks) and updates
     * m_ClusterAssignments[instIdx].
     *
     * @param instIdx index (into m_Instances) of the instance to assign
     * @return 1 if the instance changed cluster, 0 otherwise
     * @exception Exception if the penalty computation fails
     */
    public int assignInstanceToClusterWithConstraints(int instIdx) throws Exception {
        int bestCluster = 0;
        double lowestPenalty = Double.MAX_VALUE;
        int moved = 0;

        // try each cluster and find one with lowest penalty
        for (int i = 0; i < m_NumClusters; i++) {
            // penaltyForInstance fills the m_obj*CurrPoint fields as a side effect
            double penalty = penaltyForInstance(instIdx, i);

            if (penalty < lowestPenalty) {
                lowestPenalty = penalty;
                bestCluster = i;
                // snapshot the per-point components of the best cluster so far
                m_objVarianceCurrPointBest = m_objVarianceCurrPoint;
                m_objNormalizerCurrPointBest = m_objNormalizerCurrPoint;
                m_objMustLinksCurrPointBest = m_objMustLinksCurrPoint;
                m_objCannotLinksCurrPointBest = m_objCannotLinksCurrPoint;
            }
        }

        // fold the winning cluster's contribution into the global objective terms
        m_objVariance += m_objVarianceCurrPointBest;
        m_objNormalizer += m_objNormalizerCurrPointBest;
        m_objMustLinks += m_objMustLinksCurrPointBest;
        m_objCannotLinks += m_objCannotLinksCurrPointBest;

        if (m_ClusterAssignments[instIdx] != bestCluster) {
            // only report a move if the old assignment was a valid cluster index
            if (m_ClusterAssignments[instIdx] >= 0 && m_ClusterAssignments[instIdx] < m_NumClusters) {
                //if (m_verbose) {
                // NOTE(review): the verbose guard above was deliberately commented
                // out, so this move trace always prints
                System.out.println("Moving instance " + instIdx + " from cluster " + m_ClusterAssignments[instIdx]
                        + " to cluster " + bestCluster + " penalty:"
                        + ((float) penaltyForInstance(instIdx, m_ClusterAssignments[instIdx])) + "=>"
                        + ((float) lowestPenalty));
            }
            moved = 1;
            m_ClusterAssignments[instIdx] = bestCluster;
        }

        if (m_verbose) {
            System.out.println("Assigning instance " + instIdx + " to cluster " + bestCluster);
        }

        return moved;
    }

    /** Computes the penalty of assigning instance instIdx to the cluster
     * with centroid centroidIdx: the distance-to-centroid (variance) term,
     * the log-normalizer term, and penalties for must-link / cannot-link
     * constraints this assignment would violate.
     *
     * Side effects: fills m_objVarianceCurrPoint, m_objCannotLinksCurrPoint,
     * m_objMustLinksCurrPoint and m_objNormalizerCurrPoint with the
     * individual components of the returned total.
     *
     * @param instIdx index of the instance in m_Instances
     * @param centroidIdx index of the candidate cluster centroid
     * @return the total penalty for this (instance, cluster) pair
     * @exception Exception if the metric computation fails
     */
    public double penaltyForInstance(int instIdx, int centroidIdx) throws Exception {
        m_objVarianceCurrPoint = 0;
        m_objCannotLinksCurrPoint = 0;
        m_objMustLinksCurrPoint = 0;
        m_objNormalizerCurrPoint = 0;
        // counts violations seen below; tallied but not returned
        int violatedConstraints = 0;

        // variance contribution
        Instance instance = m_Instances.instance(instIdx);
        Instance centroid = m_ClusterCentroids.instance(centroidIdx);

        m_objVarianceCurrPoint = m_metrics[centroidIdx].penalty(instance, centroid);

        // regularizer and normalizer contribution
        if (m_Trainable == TRAINING_INTERNAL) {
            m_objNormalizerCurrPoint = -m_logTerms[centroidIdx];
        }

        // only add the constraints if seedable or constrained
        //    if (m_Seedable || (m_Trainable != TRAINING_NONE)) {   

        // Sugato: replacing, in order to be able to run MKMeans (no
        // constraint violation, only metric learning)
        if (m_Seedable) {
            Object list = m_instanceConstraintHash.get(new Integer(instIdx));
            if (list != null) { // there are constraints associated with this instance
                ArrayList constraintList = (ArrayList) list;
                for (int i = 0; i < constraintList.size(); i++) {
                    InstancePair pair = (InstancePair) constraintList.get(i);
                    int firstIdx = pair.first;
                    int secondIdx = pair.second;

                    Instance instance1 = m_Instances.instance(firstIdx);
                    Instance instance2 = m_Instances.instance(secondIdx);
                    // the cluster currently holding the OTHER endpoint of this constraint
                    int otherIdx = (firstIdx == instIdx) ? m_ClusterAssignments[secondIdx]
                            : m_ClusterAssignments[firstIdx];

                    // check whether the constraint is violated
                    // (skip if the other endpoint is unassigned: otherIdx == -1)
                    if (otherIdx != -1 && otherIdx < m_NumClusters) {
                        if (otherIdx != centroidIdx && pair.linkType == InstancePair.MUST_LINK) {
                            violatedConstraints++;
                            // split penalty in half between the two involved clusters
                            if (m_useMultipleMetrics) {
                                double penalty1 = m_metrics[otherIdx].penaltySymmetric(instance1, instance2);
                                double penalty2 = m_metrics[centroidIdx].penaltySymmetric(instance1, instance2);
                                m_objMustLinksCurrPoint += 0.5 * m_MLweight * (penalty1 + penalty2);
                            } else {
                                double penalty = m_metric.penaltySymmetric(instance1, instance2);
                                m_objMustLinksCurrPoint += m_MLweight * penalty;
                            }
                        } else if (otherIdx == centroidIdx && pair.linkType == InstancePair.CANNOT_LINK) {
                            violatedConstraints++;
                            // cannot-link cost grows as the pair gets closer: maxPenalty - penalty
                            double penalty = m_metrics[centroidIdx].penaltySymmetric(instance1, instance2);
                            m_objCannotLinksCurrPoint += m_CLweight * (m_maxCLPenalties[centroidIdx] - penalty);
                            if (m_maxCLPenalties[centroidIdx] - penalty < 0) {
                                System.out.println("***NEGATIVE*** penalty: " + penalty + " for CL constraint");
                            }
                        }
                    }
                }
            }
        }

        double total = m_objVarianceCurrPoint + m_objCannotLinksCurrPoint + m_objMustLinksCurrPoint
                + m_objNormalizerCurrPoint;
        if (m_verbose) {
            System.out.println(
                    "Final penalty for instance " + instIdx + " and centroid " + centroidIdx + " is: " + total);
        }
        return total;
    }

    /** M-step of the KMeans clustering algorithm -- updates cluster centroids.
     *
     *  Recomputes each centroid as the mean (or mode, for nominal attributes)
     *  of the instances currently assigned to that cluster, then optionally
     *  smooths and normalizes it depending on the metric in use.
     *
     *  Cleanup vs. previous revision: removed the unused locals
     *  {@code tempCentroids}/{@code tempNewCentroids} and the large block of
     *  commented-out debugging code that was their only reference.
     *
     *  @throws Exception if centroid computation or metric smoothing fails
     */
    protected void updateClusterCentroids() throws Exception {
        // clusterInstances[i] stores the instances assigned to cluster i
        Instances[] clusterInstances = new Instances[m_NumClusters];
        m_ClusterCentroids = new Instances(m_Instances, m_NumClusters);

        for (int i = 0; i < m_NumClusters; i++) {
            clusterInstances[i] = new Instances(m_Instances, 0);
        }
        for (int i = 0; i < m_Instances.numInstances(); i++) {
            clusterInstances[m_ClusterAssignments[i]].add(m_Instances.instance(i));
        }

        // Calculate cluster centroids
        for (int i = 0; i < m_NumClusters; i++) {
            Instance centroid = null;

            if (m_isSparseInstance) { // uses fast meanOrMode over the whole dataset
                double[] values = ClusterUtils.meanOrMode(clusterInstances[i]);
                centroid = new SparseInstance(1.0, values);
            } else { // non-sparse: go through each attribute individually
                double[] values = new double[m_Instances.numAttributes()];
                for (int j = 0; j < m_Instances.numAttributes(); j++) {
                    values[j] = clusterInstances[i].meanOrMode(j); // uses usual meanOrMode
                }
                centroid = new Instance(1.0, values);
            }

            // if we are using a smoothing metric, smooth the centroid
            if (m_metric instanceof SmoothingMetric && ((SmoothingMetric) m_metric).getUseSmoothing()) {
                System.out.println("\tSmoothing...");
                SmoothingMetric smoothingMetric = (SmoothingMetric) m_metric;
                centroid = smoothingMetric.smoothInstance(centroid);
            }

            m_ClusterCentroids.add(centroid);

            // in SPKMeans, cluster centroids need to be normalized
            if (m_metric.doesNormalizeData()) {
                m_metric.normalizeInstanceWeighted(m_ClusterCentroids.instance(i));
            }
        }

        if (m_metric instanceof SmoothingMetric && ((SmoothingMetric) m_metric).getUseSmoothing())
            updateSmoothingMetrics();

        for (int i = 0; i < m_NumClusters; i++)
            clusterInstances[i] = null; // free memory
    }

    /** M-step of the KMeans clustering algorithm -- updates metric
     *  weights. Invoked only when we're using non-Potts model
     *  and metric is trainable.
     *
     *  @throws Exception if metric training fails
     */
    protected void updateMetricWeights() throws Exception {
        if (!m_useMultipleMetrics) {
            // single shared metric: -1 means "train over all clusters"
            m_metricLearner.trainMetric(-1);
        } else {
            // one individually trained metric per cluster
            for (int clusterIdx = 0; clusterIdx < m_NumClusters; clusterIdx++) {
                m_metricLearners[clusterIdx].trainMetric(clusterIdx);
            }
        }
        // normalizer/regularizer terms depend on the metric weights, so refresh them
        InitNormalizerRegularizer();
    }

    /** Checks whether clustering has converged.
     *
     *  Convergence is declared when the objective function has stopped
     *  changing, or when either iteration limit (blank or total) is reached.
     *  Each triggered condition reports the final objective on stdout.
     *
     *  @param oldObjective objective value from the previous iteration
     *  @param newObjective objective value from the current iteration
     *  @return true if any stopping criterion is satisfied
     *  @throws Exception declared for interface compatibility
     */
    public boolean convergenceCheck(double oldObjective, double newObjective) throws Exception {
        boolean objectiveStable = Math.abs(oldObjective - newObjective) < m_ObjFunConvergenceDifference;
        boolean blankLimitHit = m_numBlankIterations >= m_maxBlankIterations;
        boolean iterationLimitHit = m_Iterations >= m_maxIterations;

        if (objectiveStable) {
            System.out.println("Final objective function is: " + newObjective);
        }
        if (blankLimitHit) {
            System.out.println("Max blank iterations reached ...\n");
            System.out.println("Final objective function is: " + newObjective);
        }
        if (iterationLimitHit) {
            System.out.println("Max iterations reached ...\n");
            System.out.println("Final objective function is: " + newObjective);
        }

        return objectiveStable || blankLimitHit || iterationLimitHit;
    }

    /** Calculates the objective function over all instances.
     *
     *  Accumulates, per instance, the variance, must-link, cannot-link and
     *  normalizer components (via penaltyForInstance), then subtracts the
     *  regularizer. Constraint weights are temporarily halved because each
     *  pairwise constraint is visited from both endpoints.
     *
     *  @param isComplete if true, the previous m_Objective is saved into
     *         m_OldObjective before being overwritten (i.e. the previous
     *         estimate was a complete one)
     *  @return the new value of m_Objective
     *  @throws Exception if penalty computation fails
     */
    public double calculateObjectiveFunction(boolean isComplete) throws Exception {
        System.out.println("\tCalculating objective function ...");

        // update the oldObjective only if previous estimate of m_Objective
        // was complete
        if (isComplete) {
            m_OldObjective = m_Objective;
        }
        m_Objective = 0;
        m_objVariance = 0;
        m_objMustLinks = 0;
        m_objCannotLinks = 0;
        m_objNormalizer = 0;

        // Some debugging code:  tracking per-cluster objective
        double[] objectives = new double[m_NumClusters];

        // temporarily halve weights since every constraint is counted twice
        // (once from each endpoint of the pair); restored below
        double tempML = m_MLweight;
        double tempCL = m_CLweight;
        m_MLweight = tempML / 2;
        m_CLweight = tempCL / 2;

        if (m_verbose) {
            System.out.println("Must link weight: " + m_MLweight);
            System.out.println("Cannot link weight: " + m_CLweight);
        }

        for (int i = 0; i < m_Instances.numInstances(); i++) {
            if (m_isOfflineMetric) {
                // offline metric: plain distance-to-centroid, no constraint terms
                double dist = m_metric.penalty(m_Instances.instance(i),
                        m_ClusterCentroids.instance(m_ClusterAssignments[i]));
                m_Objective += dist;
                if (m_verbose) {
                    System.out.println("Component for " + i + " = " + dist);
                }
            } else {
                // penaltyForInstance also fills the m_obj*CurrPoint fields,
                // which are folded into the per-component totals here
                double penalty = penaltyForInstance(i, m_ClusterAssignments[i]);
                objectives[m_ClusterAssignments[i]] += penalty;
                m_Objective += penalty;
                m_objVariance += m_objVarianceCurrPoint;
                m_objMustLinks += m_objMustLinksCurrPoint;
                m_objCannotLinks += m_objCannotLinksCurrPoint;
                m_objNormalizer += m_objNormalizerCurrPoint;
            }
        }

        m_Objective -= m_objRegularizer;

        m_MLweight = tempML;
        m_CLweight = tempCL; // reset the values of the constraint weights

        // debugging:  reporting per-cluster objectives
        for (int i = 0; i < m_NumClusters; i++) {
            System.out.println("\t\tCluster " + i + " obj=" + objectives[i]);
        }
        System.out.println("\tTotalObj=" + m_Objective);

        // Oscillation check: the objective should be non-increasing across
        // EM iterations; an increase indicates a bug in the E or M step.
        // Cast to float to ignore tiny floating-point noise.
        if ((float) m_OldObjective < (float) m_Objective) {
            System.out.println("WHOA!!!  Oscillations => bug in EM step?");
            System.out.println(
                    "Old objective:" + (float) m_OldObjective + " < New objective: " + (float) m_Objective);
        }

        return m_Objective;
    }

    /** Actual KMeans function: the main EM loop.
     *
     *  Repeats E-step (point assignment), M-step (centroid update) and, for
     *  internally trainable metrics, metric-weight learning, until
     *  convergenceCheck signals convergence. Optionally logs constraint
     *  incoherence per iteration to m_ConstraintIncoherenceFile.
     *
     *  @throws Exception if any of the E/M steps or metric training fails
     */
    protected void runKMeans() throws Exception {
        boolean converged = false;
        m_Iterations = 0;
        m_numBlankIterations = 0;
        m_Objective = Double.POSITIVE_INFINITY;

        if (!m_isOfflineMetric) {
            // reset learnable metrics (and their learners) before clustering
            if (m_useMultipleMetrics) {
                for (int i = 0; i < m_metrics.length; i++) {
                    m_metrics[i].resetMetric();
                    m_metricLearners[i].resetLearner();
                }
            } else {
                m_metric.resetMetric();
                m_metricLearner.resetLearner();
            }
            // initialize max CL penalties
            if (m_ConstraintsHash.size() > 0) {
                m_maxCLPenalties = calculateMaxCLPenalties();
            }
        }

        // initialize m_ClusterAssignments; -1 marks "not yet assigned"
        // NOTE(review): only the first m_NumClusters entries are reset here,
        // not all numInstances entries -- verify this is intentional
        for (int i = 0; i < m_NumClusters; i++) {
            m_ClusterAssignments[i] = -1;
        }

        PrintStream fincoh = null;
        if (m_ConstraintIncoherenceFile != null) {
            fincoh = new PrintStream(new FileOutputStream(m_ConstraintIncoherenceFile));
        }

        while (!converged) {
            System.out.println("\n" + m_Iterations + ". Objective function: " + ((float) m_Objective));
            m_OldObjective = m_Objective;

            // E-step
            int numMovedPoints = findBestAssignments();

            // a "blank" iteration is one in which no point changed cluster
            m_numBlankIterations = (numMovedPoints == 0) ? m_numBlankIterations + 1 : 0;

            //      calculateObjectiveFunction(false);
            System.out.println((float) m_Objective + " - Objective function after point assignment(CALC)");
            System.out.println("\tvar=" + ((float) m_objVariance) + "\tC=" + ((float) m_objCannotLinks) + "\tM="
                    + ((float) m_objMustLinks) + "\tLOG=" + ((float) m_objNormalizer) + "\tREG="
                    + ((float) m_objRegularizer));

            // M-step
            updateClusterCentroids();

            //      calculateObjectiveFunction(false);
            System.out.println((float) m_Objective + " - Objective function after centroid estimation");
            System.out.println("\tvar=" + ((float) m_objVariance) + "\tC=" + ((float) m_objCannotLinks) + "\tM="
                    + ((float) m_objMustLinks) + "\tLOG=" + ((float) m_objNormalizer) + "\tREG="
                    + ((float) m_objRegularizer));

            // second half of M-step: learn metric weights (only for
            // internally trainable, non-offline metrics)
            if (m_Trainable == TRAINING_INTERNAL && !m_isOfflineMetric) {
                updateMetricWeights();
                if (m_verbose) {
                    calculateObjectiveFunction(true);
                    System.out.println((float) m_Objective + " - Objective function after metric update");
                    System.out.println("\tvar=" + ((float) m_objVariance) + "\tC=" + ((float) m_objCannotLinks)
                            + "\tM=" + ((float) m_objMustLinks) + "\tLOG=" + ((float) m_objNormalizer) + "\tREG="
                            + ((float) m_objRegularizer));
                }

                // metric changed, so max CL penalties must be recomputed
                if (m_ConstraintsHash.size() > 0) {
                    m_maxCLPenalties = calculateMaxCLPenalties();
                }
            }

            if (fincoh != null) {
                printConstraintIncoherence(fincoh);
            }

            converged = convergenceCheck(m_OldObjective, m_Objective);
            m_Iterations++;
        }

        if (fincoh != null) {
            fincoh.close();
        }
        System.out.println("Converged!");
        System.err.print("Its\t" + m_Iterations + "\t");

        if (m_verbose) {
            // dump the 5 highest-valued attributes of each centroid
            System.out.println("Done clustering; top cluster features: ");
            for (int i = 0; i < m_NumClusters; i++) {
                System.out.println("Centroid " + i);
                // reverse-ordered TreeMap => iteration yields largest values first;
                // NOTE(review): attributes with equal values collide on the key
                TreeMap map = new TreeMap(Collections.reverseOrder());
                Instance centroid = m_ClusterCentroids.instance(i);
                for (int j = 0; j < centroid.numValues(); j++) {
                    Attribute attr = centroid.attributeSparse(j);
                    map.put(new Double(centroid.value(attr)), attr.name());
                }
                Iterator it = map.entrySet().iterator();
                for (int j = 0; j < 5 && it.hasNext(); j++) {
                    Map.Entry entry = (Map.Entry) it.next();
                    System.out.println("\t" + entry.getKey() + "\t" + entry.getValue());
                }
            }
        }
    }

    /** Computes and reports constraint incoherence under the current metric.
     *
     *  For every (must-link, cannot-link) pair of constraints, a violation is
     *  counted when the must-linked pair is farther apart than the
     *  cannot-linked pair. Incoherence is the violation count normalized by
     *  the number of ML/CL pairs. Note this is O(n^2) in the number of
     *  constraints.
     *
     *  Fix vs. previous revision: guards against division by zero when there
     *  are no ML or no CL constraints (previously produced NaN/Infinity).
     *
     *  @param fincoh stream to write the incoherence line to; if null, a
     *         summary is printed to stdout instead
     *  @throws Exception if distance computation fails
     */
    public void printConstraintIncoherence(PrintStream fincoh) throws Exception {
        Object[] array = m_ConstraintsHash.entrySet().toArray();

        int numML = 0, numCL = 0;
        double incoh = 0;

        m_numViolations = 0;

        System.out.println("NumConstraints: " + array.length);
        for (int i = 0; i < array.length; i++) {
            Map.Entry con1 = (Map.Entry) array[i];
            InstancePair pair1 = (InstancePair) con1.getKey();
            int link1 = ((Integer) con1.getValue()).intValue();
            double dist1 = m_metric.distance(m_Instances.instance(pair1.first), m_Instances.instance(pair1.second));
            if (link1 == InstancePair.MUST_LINK) {
                numML++;
            } else if (link1 == InstancePair.CANNOT_LINK) {
                numCL++;
            }

            // compare against every later constraint of the opposite type
            for (int j = i + 1; j < array.length; j++) {
                Map.Entry con2 = (Map.Entry) array[j];
                InstancePair pair2 = (InstancePair) con2.getKey();
                int link2 = ((Integer) con2.getValue()).intValue();
                double dist2 = m_metric.distance(m_Instances.instance(pair2.first),
                        m_Instances.instance(pair2.second));

                if (link1 == InstancePair.MUST_LINK) {
                    if (link2 == InstancePair.CANNOT_LINK) {
                        if (dist1 > dist2) { // ML pair farther than CL pair => violation
                            m_numViolations++;
                        }
                    }
                } else if (link1 == InstancePair.CANNOT_LINK) {
                    if (link2 == InstancePair.MUST_LINK) {
                        if (dist1 < dist2) { // CL pair closer than ML pair => violation
                            m_numViolations++;
                        }
                    }
                }
            }
        }

        // avoid NaN/Infinity when one of the constraint types is absent
        int numPairs = numCL * numML;
        incoh = (numPairs > 0) ? (m_numViolations * 1.0) / numPairs : 0;

        if (fincoh != null) {
            fincoh.println("Iterations\t" + (m_Iterations + 1) + "\tIncoh\t" + incoh);
        } else {
            System.out.println((m_Iterations + 1) + "\tNumViolations\t" + m_numViolations + "\tNumTotalCL\t" + numCL
                    + "\tNumTotalML\t" + numML);
        }
    }

    /** Resets the objective function value and every one of its components
     *  (variance, cannot-link, must-link, normalizer, regularizer) to zero. */
    public void resetObjective() {
        m_objRegularizer = 0;
        m_objNormalizer = 0;
        m_objMustLinks = 0;
        m_objCannotLinks = 0;
        m_objVariance = 0;
        m_Objective = 0;
    }

    /** Go through the cannot-link constraints and find the current maximum distance.
     *
     *  Finds, per cluster, the pair of extreme points (min/max per attribute,
     *  or WeightedMahalanobis-specific extreme points) and returns the
     *  symmetric penalty between them, used as the CL penalty ceiling.
     *
     *  Fix vs. previous revision (the branch marked "SHOULD BE FIXED!!!!"):
     *  the cluster-specific case accumulated penaltySymmetric once per
     *  attribute (attrIdxs.length times too many) and built the diff instance
     *  with m_metrics[0] instead of m_metrics[j]; both are corrected below.
     *
     *  @return an array of maximum weighted distances.  If a single metric is
     *  used, maximum distance is calculated over the entire dataset
     *  @throws Exception if penalty computation fails
     */
    // TODO:  non-datasetWide case is not debugged currently!!!
    protected double[] calculateMaxCLPenalties() throws Exception {
        double[] maxPenalties = new double[m_NumClusters];
        m_maxCLPoints = new Instance[m_NumClusters][2];
        m_maxCLDiffInstances = new Instance[m_NumClusters];

        for (int i = 0; i < m_NumClusters; i++) {
            m_maxCLPoints[i][0] = new Instance(m_Instances.numAttributes());
            m_maxCLPoints[i][1] = new Instance(m_Instances.numAttributes());
            m_maxCLPoints[i][0].setDataset(m_Instances);
            m_maxCLPoints[i][1].setDataset(m_Instances);
            m_maxCLDiffInstances[i] = new Instance(m_Instances.numAttributes());
            m_maxCLDiffInstances[i].setDataset(m_Instances);
        }

        // NOTE(review): arrays start at 0, so an all-positive (or all-negative)
        // attribute keeps 0 as its min (resp. max) -- confirm this is intended
        double[][] minValues = new double[m_NumClusters][m_metrics[0].getNumAttributes()];
        double[][] maxValues = new double[m_NumClusters][m_metrics[0].getNumAttributes()];
        int[] attrIdxs = m_metrics[0].getAttrIndxs();

        // temporary plug:  if this is the first iteration when no instances
        // were assigned to clusters, dataset-wide (not cluster-wide!) minimum
        // and maximum are used even for the case with multiple metrics
        boolean datasetWide = true;
        if (m_useMultipleMetrics && m_Iterations > 0) {
            datasetWide = false;
        }

        // TODO:  Mahalanobis - check with getMaxPoints
        if (m_metric instanceof WeightedMahalanobis) {
            // metric-specific extreme points
            if (m_useMultipleMetrics) {
                for (int i = 0; i < m_metrics.length; i++) {
                    double[][] maxPoints = ((WeightedMahalanobis) m_metrics[i]).getMaxPoints(m_ConstraintsHash,
                            m_Instances);
                    minValues[i] = maxPoints[0];
                    maxValues[i] = maxPoints[1];
                }
            } else {
                // single metric: replicate its extreme points to every cluster slot
                double[][] maxPoints = ((WeightedMahalanobis) m_metric).getMaxPoints(m_ConstraintsHash,
                        m_Instances);
                for (int i = 0; i < m_metrics.length; i++) {
                    minValues[i] = maxPoints[0];
                    maxValues[i] = maxPoints[1];
                }
            }
        } else { // find the enclosing hypercube for WeightedEuclidean etc.
            for (int i = 0; i < m_Instances.numInstances(); i++) {
                Instance instance = m_Instances.instance(i);
                for (int j = 0; j < attrIdxs.length; j++) {
                    double val = instance.value(attrIdxs[j]);
                    if (datasetWide) {
                        if (val < minValues[0][j]) {
                            minValues[0][j] = val;
                        }
                        if (val > maxValues[0][j]) {
                            maxValues[0][j] = val;
                        }
                    } else { // cluster-specific min's and max's  are needed
                        if (val < minValues[m_ClusterAssignments[i]][j]) {
                            minValues[m_ClusterAssignments[i]][j] = val;
                        }
                        if (val > maxValues[m_ClusterAssignments[i]][j]) {
                            maxValues[m_ClusterAssignments[i]][j] = val;
                        }
                    }
                }
            }
        }

        // materialize the extreme points as Instances
        if (datasetWide) {
            for (int i = 0; i < attrIdxs.length; i++) {
                m_maxCLPoints[0][0].setValue(attrIdxs[i], minValues[0][i]);
                m_maxCLPoints[0][1].setValue(attrIdxs[i], maxValues[0][i]);
            }
            // must copy these over all clusters - just for the first iteration
            for (int j = 1; j < m_NumClusters; j++) {
                for (int i = 0; i < attrIdxs.length; i++) {
                    m_maxCLPoints[j][0].setValue(attrIdxs[i], minValues[0][i]);
                    m_maxCLPoints[j][1].setValue(attrIdxs[i], maxValues[0][i]);
                }
            }
        } else { // cluster-specific
            for (int j = 0; j < m_NumClusters; j++) {
                for (int i = 0; i < attrIdxs.length; i++) {
                    m_maxCLPoints[j][0].setValue(attrIdxs[i], minValues[j][i]);
                    m_maxCLPoints[j][1].setValue(attrIdxs[i], maxValues[j][i]);
                }
            }
        }

        // calculate the penalties between the extreme points
        if (datasetWide) {
            maxPenalties[0] = m_metrics[0].penaltySymmetric(m_maxCLPoints[0][0], m_maxCLPoints[0][1]);
            m_maxCLDiffInstances[0] = m_metrics[0].createDiffInstance(m_maxCLPoints[0][0], m_maxCLPoints[0][1]);
            for (int i = 1; i < maxPenalties.length; i++) {
                maxPenalties[i] = maxPenalties[0];
                m_maxCLDiffInstances[i] = m_maxCLDiffInstances[0];
            }
        } else { // cluster-specific: one penalty per cluster, using that cluster's metric
            for (int j = 0; j < m_NumClusters; j++) {
                maxPenalties[j] = m_metrics[j].penaltySymmetric(m_maxCLPoints[j][0], m_maxCLPoints[j][1]);
                m_maxCLDiffInstances[j] = m_metrics[j].createDiffInstance(m_maxCLPoints[j][0],
                        m_maxCLPoints[j][1]);
            }
        }
        System.out.println("Recomputed max CL penalties");
        return maxPenalties;
    }

    /**
     * Checks if instance has to be normalized and classifies the
     * instance using the current clustering.
     *
     * Delegates to assignInstanceToCluster, which in turn performs a
     * transductive lookup of the (already clustered) instance.
     *
     * @param instance the instance to be assigned to a cluster
     * @return the number of the assigned cluster as an integer
     * if the class is enumerated, otherwise the predicted value
     * @exception Exception if instance could not be classified
     * successfully */

    public int clusterInstance(Instance instance) throws Exception {
        return assignInstanceToCluster(instance);
    }

    /** lookup the instance in the checksum hash, assuming transductive clustering.
     *
     *  The checksum is a weighted sum of attribute values (class attribute
     *  excluded), truncated to float precision to tolerate rounding noise.
     *  Candidates sharing the checksum are then compared attribute-by-attribute
     *  (again at float precision) to find an exact match.
     *
     * @param instance instance to be looked up
     * @return the index of the cluster to which the instance was assigned, -1 if the instance has not bee clustered
     */
    protected int lookupInstanceCluster(Instance instance) throws Exception {
        int classIdx = instance.classIndex();
        double checksum = 0;

        // need to normalize using original metric, since cluster data is normalized similarly
        if (m_metric.doesNormalizeData()) {
            if (m_Trainable == TRAINING_INTERNAL) {
                // reset to the untrained metric so normalization matches the
                // normalization applied to the training data
                m_metric.resetMetric();
            }
            m_metric.normalizeInstanceWeighted(instance);
        }

        // compute the checksum over all non-class attributes
        double[] values1 = instance.toDoubleArray();
        for (int i = 0; i < values1.length; i++) {
            if (i != classIdx) {
                checksum += m_checksumCoeffs[i] * values1[i];
            }
        }

        // float cast: keys were stored at float precision, so look up the same way
        Object list = m_checksumHash.get(new Double((float) checksum));
        if (list != null) {
            // go through the list of instances with the same checksum and find the one that is equivalent
            ArrayList checksumList = (ArrayList) list;
            for (int i = 0; i < checksumList.size(); i++) {
                int instanceIdx = ((Integer) checksumList.get(i)).intValue();
                Instance listInstance = m_Instances.instance(instanceIdx);
                double[] values2 = listInstance.toDoubleArray();
                boolean equal = true;
                // attribute-wise comparison at float precision, skipping the class
                for (int j = 0; j < values1.length && equal == true; j++) {
                    if (j != classIdx) {
                        if ((float) values1[j] != (float) values2[j]) {
                            equal = false;
                        }
                    }
                }
                if (equal == true) {
                    return m_ClusterAssignments[instanceIdx];
                }
            }
        }
        // not found: the instance was never clustered
        return -1;
    }

    /**
     * Classifies the instances using the current clustering, moves
     * must-linked points together (Xing's approach)
     *
     * Assigns every instance to its nearest centroid under m_metric.
     * Intended for use with offline metrics (BarHillelMetric/XingMetric);
     * warns otherwise.
     *
     * NOTE(review): instanceAlreadyAssigned is never set before the guard
     * below (nothing populates it between initialization and the loop), so
     * the "already in some ML neighborhood" skip appears to be vestigial --
     * verify against the version that built ML neighborhood sets.
     *
     * @return the number of instances whose cluster assignment changed
     * @exception Exception if instance could not be classified
     * successfully */

    public int assignAllInstancesToClusters() throws Exception {
        int numInstances = m_Instances.numInstances();
        boolean[] instanceAlreadyAssigned = new boolean[numInstances];
        int moved = 0;

        if (!m_isOfflineMetric) {
            System.err.println(
                    "WARNING!!!\n\nThis code should not be called if metric is not a BarHillelMetric or XingMetric!!!!\n\n");
        }

        for (int i = 0; i < numInstances; i++) {
            instanceAlreadyAssigned[i] = false;
        }

        // now process points not in ML meighborhood sets
        for (int instIdx = 0; instIdx < numInstances; instIdx++) {
            if (instanceAlreadyAssigned[instIdx]) {
                continue; // was already in some ML neighborhood
            }
            // nearest-centroid search
            int bestCluster = 0;
            double bestDistance = Double.POSITIVE_INFINITY;
            for (int centroidIdx = 0; centroidIdx < m_NumClusters; centroidIdx++) {
                double sqDistance = m_metric.distance(m_Instances.instance(instIdx),
                        m_ClusterCentroids.instance(centroidIdx));
                if (sqDistance < bestDistance) {
                    bestDistance = sqDistance;
                    bestCluster = centroidIdx;
                }
            }

            // do we need to reassign the point?
            if (m_ClusterAssignments[instIdx] != bestCluster) {
                m_ClusterAssignments[instIdx] = bestCluster;
                instanceAlreadyAssigned[instIdx] = true;
                moved++;
            }
        }
        return moved;
    }

    /**
     * Classifies the instance using the current clustering, without considering constraints.
     *
     * Performs a transductive lookup: the instance (converted first if the
     * metric requires it) must already be present in the checksum hash built
     * during clustering; otherwise an exception is thrown.
     *
     * Cleanup vs. previous revision: removed the unused locals bestCluster,
     * bestDistance and bestSimilarity (leftovers from a nearest-centroid
     * search that this method no longer performs).
     *
     * @param instance the instance to be assigned to a cluster
     * @return the number of the assigned cluster as an integer
     * if the class is enumerated, otherwise the predicted value
     * @exception Exception if instance could not be classified
     * successfully */

    public int assignInstanceToCluster(Instance instance) throws Exception {
        int lookupCluster;

        // some metrics operate on a converted representation of the instance
        if (m_metric instanceof InstanceConverter) {
            Instance newInstance = ((InstanceConverter) m_metric).convertInstance(instance);
            lookupCluster = lookupInstanceCluster(newInstance);
        } else {
            lookupCluster = lookupInstanceCluster(instance);
        }
        if (lookupCluster >= 0) {
            return lookupCluster;
        }
        throw new Exception(
                "ACHTUNG!!!\n\nCouldn't lookup the instance!!! Size of hash = " + m_checksumHash.size());
    }

    /** Set the cannot link constraint weight.
     *  @param w the new cannot-link penalty weight */
    public void setCannotLinkWeight(double w) {
        m_CLweight = w;
    }

    /** Return the cannot link constraint weight.
     *  @return the current cannot-link penalty weight */
    public double getCannotLinkWeight() {
        return m_CLweight;
    }

    /** Set the must link constraint weight.
     *  @param w the new must-link penalty weight */
    public void setMustLinkWeight(double w) {
        m_MLweight = w;
    }

    /** Return the must link constraint weight.
     *  @return the current must-link penalty weight */
    public double getMustLinkWeight() {
        return m_MLweight;
    }

    /** Return the number of clusters.
     *  @return the number of clusters this clusterer was configured with */
    public int getNumClusters() {
        return m_NumClusters;
    }

    /** A duplicate function to conform to Clusterer abstract class.
     * @returns the number of clusters
     */
    public int numberOfClusters() {
        return getNumClusters();
    }

    /** Set the m_SeedHash */
    public void setSeedHash(HashMap seedhash) {
        System.err.println("Not implemented here");
    }

    /**
     * Set the random number seed
     * @param s the seed
     */
    public void setRandomSeed(int s) {
        m_RandomSeed = s;
    }

    /** Return the random number seed.
     * @return the seed used for random number generation */
    public int getRandomSeed() {
        return m_RandomSeed;
    }

    /** Set the maximum number of iterations.
     * @param maxIterations upper bound on k-means iterations */
    public void setMaxIterations(int maxIterations) {
        m_maxIterations = maxIterations;
    }

    /** Get the maximum number of iterations.
     * @return upper bound on k-means iterations */
    public int getMaxIterations() {
        return m_maxIterations;
    }

    /** Set the maximum number of blank iterations (those where no points are moved).
     * @param maxBlankIterations upper bound on consecutive iterations without reassignments */
    public void setMaxBlankIterations(int maxBlankIterations) {
        m_maxBlankIterations = maxBlankIterations;
    }

    /** Get the maximum number of blank iterations.
     * @return upper bound on consecutive iterations without reassignments */
    public int getMaxBlankIterations() {
        return m_maxBlankIterations;
    }

    /**
     * Set the minimum value of the objective function difference required for convergence
     * @param objFunConvergenceDifference the minimum value of the objective function difference required for convergence
     */
    public void setObjFunConvergenceDifference(double objFunConvergenceDifference) {
        m_ObjFunConvergenceDifference = objFunConvergenceDifference;
    }

    /**
     * Get the minimum value of the objective function difference required for convergence
     * @return the minimum value of the objective function difference required for convergence
     */
    public double getObjFunConvergenceDifference() {
        return m_ObjFunConvergenceDifference;
    }

    /** Sets training instances and builds a checksum hash for later instance lookup.
     *
     * Each instance is keyed by a pseudo-random weighted sum of its attribute
     * values (the class attribute excluded); instances mapping to the same
     * checksum are chained in a list of their indices.
     *
     * @param instances the training instances
     */
    public void setInstances(Instances instances) {
        m_Instances = instances;

        // create the checksum coefficients
        m_checksumCoeffs = new double[instances.numAttributes()];
        for (int i = 0; i < m_checksumCoeffs.length; i++) {
            m_checksumCoeffs[i] = m_RandomNumberGenerator.nextDouble();
        }

        // hash the instance checksums
        m_checksumHash = new HashMap(instances.numInstances());
        int classIdx = instances.classIndex();
        for (int i = 0; i < instances.numInstances(); i++) {
            Instance instance = instances.instance(i);
            double[] values = instance.toDoubleArray();
            double checksum = 0;

            for (int j = 0; j < values.length; j++) {
                if (j != classIdx) { // the class attribute does not contribute to the key
                    checksum += m_checksumCoeffs[j] * values[j];
                }
            }

            // take care of chaining
            // NOTE(review): the (float) cast coarsens the key, presumably to
            // tolerate floating-point round-off between equal instances — confirm.
            Object list = m_checksumHash.get(new Double((float) checksum));
            ArrayList idxList = null;
            if (list == null) {
                idxList = new ArrayList();
                m_checksumHash.put(new Double((float) checksum), idxList);
            } else { // chaining
                idxList = (ArrayList) list;
            }
            idxList.add(new Integer(i));
        }
    }

    /** Return training instances.
     * @return the instances the clusterer was trained on */
    public Instances getInstances() {
        return m_Instances;
    }

    /**
     * Set the number of clusters to generate
     *
     * @param n the number of clusters to generate
     */
    public void setNumClusters(int n) {
        m_NumClusters = n;
        if (m_verbose) {
            System.out.println("Number of clusters: " + n);
        }
    }

    /** Is the objective function decreasing or increasing?
     * @return true if lower objective values are better */
    public boolean isObjFunDecreasing() {
        return m_objFunDecreasing;
    }

    /**
     * Set the distance metric and rewire the metric learner to use it.
     *
     * @param m the metric to use for distance computations
     */
    public void setMetric(LearnableMetric m) {
        m_metric = m;
        // Keep the metric learner consistent with the newly assigned metric.
        m_metricLearner.setMetric(m_metric);
        m_metricLearner.setClusterer(this);
    }

    /**
     * get the distance metric
     * @return the distance metric used
     */
    public LearnableMetric getMetric() {
        return m_metric;
    }

    /**
     * get the array of metrics (one per cluster when multiple metrics are enabled)
     * @return the array of per-cluster metrics
     */
    public LearnableMetric[] getMetrics() {
        return m_metrics;
    }

    /** Set the metric learner and point it at the current metric and this clusterer.
     * @param ml the metric learner to use */
    public void setMetricLearner(MPCKMeansMetricLearner ml) {
        m_metricLearner = ml;
        m_metricLearner.setMetric(m_metric);
        m_metricLearner.setClusterer(this);
    }

    /** Get the metric learner.
     * @return the current metric learner */
    public MPCKMeansMetricLearner getMetricLearner() {
        return m_metricLearner;
    }

    /** Get the assigner.
     * @return the component that assigns instances to clusters */
    public MPCKMeansAssigner getAssigner() {
        return m_Assigner;
    }

    /** Set the assigner and point it back at this clusterer.
     * @param assigner the cluster-assignment component to use */
    public void setAssigner(MPCKMeansAssigner assigner) {
        assigner.setClusterer(this);
        this.m_Assigner = assigner;
    }

    /** Get the initializer.
     * @return the component that produces the initial cluster centroids */
    public MPCKMeansInitializer getInitializer() {
        return m_Initializer;
    }

    /** Set the initializer and point it back at this clusterer.
     * @param initializer the cluster-initialization component to use */
    public void setInitializer(MPCKMeansInitializer initializer) {
        initializer.setClusterer(this);
        this.m_Initializer = initializer;
    }

    /** Read the seeds from a hashtable, where every key is an instance and every value is:
     * the cluster assignment of that instance 
     * seedVector vector containing seeds
     *
     * NOTE(review): intentionally unsupported in this class — only logs to stderr.
     */

    public void seedClusterer(HashMap seedHash) {
        System.err.println("Not implemented here");
    }

    /** Prints the cluster assignment of every training instance, one
     *  "index TAB cluster" line per instance — to the configured output file
     *  when one is set, otherwise to standard output (with a header). */
    public void printClusterAssignments() throws Exception {
        boolean writeToFile = (m_ClusterAssignmentsOutputFile != null);
        PrintStream out = writeToFile
                ? new PrintStream(new FileOutputStream(m_ClusterAssignmentsOutputFile))
                : System.out;

        if (!writeToFile) {
            out.println("\nCluster Assignments:\n");
        }
        for (int idx = 0; idx < m_Instances.numInstances(); idx++) {
            out.println(idx + "\t" + m_ClusterAssignments[idx]);
        }
        if (writeToFile) {
            out.close(); // only the file stream is closed, never System.out
        }
    }

    /** Prints each cluster and its member instances to standard output.
     *
     *  Fix: the original called currentCluster.size() BEFORE the null check,
     *  so any cluster slot left null by getClusters() (a cluster with no
     *  assigned instances) caused a NullPointerException.
     */
    public void printClusters() throws Exception {
        ArrayList clusters = getClusters();
        for (int i = 0; i < clusters.size(); i++) {
            Cluster currentCluster = (Cluster) clusters.get(i);
            if (currentCluster == null) {
                // getClusters() leaves the slot null when no instance maps to it
                System.out.println("\nCluster " + i + ": 0 instances");
                System.out.println("(empty)");
            } else {
                System.out.println("\nCluster " + i + ": " + currentCluster.size() + " instances");
                for (int j = 0; j < currentCluster.size(); j++) {
                    Instance instance = (Instance) currentCluster.get(j);
                    System.out.println("Instance: " + instance);
                }
            }
        }
    }

    /**
     * Computes the clusters from the cluster assignments, for external access.
     *
     * A slot in the returned list is null when no instance was assigned to
     * that cluster.
     *
     * @return list of Cluster objects, one slot per cluster
     * @exception Exception if clusters could not be computed successfully
     */
    public ArrayList getClusters() throws Exception {
        m_Clusters = new ArrayList();
        Cluster[] buckets = new Cluster[m_NumClusters];

        // Bucket every instance by its current assignment, creating clusters lazily.
        int numInstances = m_Instances.numInstances();
        for (int idx = 0; idx < numInstances; idx++) {
            int assignment = m_ClusterAssignments[idx];
            if (buckets[assignment] == null) {
                buckets[assignment] = new Cluster();
            }
            buckets[assignment].add(m_Instances.instance(idx), 1);
        }

        for (int c = 0; c < m_NumClusters; c++) {
            m_Clusters.add(buckets[c]);
        }
        return m_Clusters;
    }

    /**
     * Computes clusters of instance indices from the cluster assignments,
     * for external access.
     *
     * @return array of HashSets of Integer instance indices, one per cluster;
     * a slot is null when no instance maps to that cluster
     * @exception Exception if clusters could not be computed successfully
     */
    public HashSet[] getIndexClusters() throws Exception {
        m_IndexClusters = new HashSet[m_NumClusters];
        int total = m_Instances.numInstances();

        for (int idx = 0; idx < total; idx++) {
            int cluster = m_ClusterAssignments[idx];
            // skip unassigned (-1) or out-of-range assignments
            if (cluster == -1 || cluster >= m_NumClusters) {
                continue;
            }
            if (m_IndexClusters[cluster] == null) {
                m_IndexClusters[cluster] = new HashSet();
            }
            m_IndexClusters[cluster].add(new Integer(idx));
        }
        return m_IndexClusters;
    }

    /** Returns an enumeration describing the available options.
     * NOTE(review): returns null rather than an empty Enumeration — callers
     * iterating the result must null-check; confirm against OptionHandler usage. */
    public Enumeration listOptions() {
        return null;
    }

    /** Gets the current option settings as an array of strings.
     *
     * Trailing unused slots are padded with empty strings, following the
     * Weka getOptions convention.
     *
     * NOTE(review): "-T" is emitted here as "Int"/"Ext", but setOptions parses
     * the -T value with Integer.parseInt — the two do not round-trip; confirm
     * the intended option format.  Similarly "-V" is emitted with a value here
     * but read as a bare flag in setOptions.
     *
     * @return an array of option strings
     */
    public String[] getOptions() {
        String[] options = new String[150];
        int current = 0;

        if (!m_Seedable) {
            options[current++] = "-X";
        }

        if (m_Trainable != TRAINING_NONE) {
            options[current++] = "-T";
            if (m_Trainable == TRAINING_INTERNAL) {
                options[current++] = "Int";
            } else {
                options[current++] = "Ext";
            }
        }

        // metric class name (package prefix stripped) plus its own options
        options[current++] = "-M";
        options[current++] = Utils.removeSubstring(m_metric.getClass().getName(), "weka.core.metrics.");
        if (m_metric instanceof OptionHandler) {
            String[] metricOptions = ((OptionHandler) m_metric).getOptions();
            for (int i = 0; i < metricOptions.length; i++) {
                options[current++] = metricOptions[i];
            }
        }

        // metric learner, only when training is enabled
        if (m_Trainable != TRAINING_NONE) {
            options[current++] = "-L";
            options[current++] = Utils.removeSubstring(m_metricLearner.getClass().getName(),
                    "weka.clusterers.metriclearners.");
            String[] metricLearnerOptions = ((OptionHandler) m_metricLearner).getOptions();
            for (int i = 0; i < metricLearnerOptions.length; i++) {
                options[current++] = metricLearnerOptions[i];
            }
        }

        // weight regularizer, only when regularization is enabled
        if (m_regularize) {
            options[current++] = "-G";
            options[current++] = Utils.removeSubstring(m_metric.getRegularizer().getClass().getName(),
                    "weka.clusterers.regularizers.");
            if (m_metric.getRegularizer() instanceof OptionHandler) {
                String[] regularizerOptions = ((OptionHandler) m_metric.getRegularizer()).getOptions();
                for (int i = 0; i < regularizerOptions.length; i++) {
                    options[current++] = regularizerOptions[i];
                }
            }
        }

        // cluster assigner and its options
        options[current++] = "-A";
        options[current++] = Utils.removeSubstring(m_Assigner.getClass().getName(), "weka.clusterers.assigners.");
        if (m_Assigner instanceof OptionHandler) {
            String[] assignerOptions = ((OptionHandler) m_Assigner).getOptions();
            for (int i = 0; i < assignerOptions.length; i++) {
                options[current++] = assignerOptions[i];
            }
        }

        // cluster initializer and its options
        options[current++] = "-I";
        options[current++] = Utils.removeSubstring(m_Initializer.getClass().getName(),
                "weka.clusterers.initializers.");
        if (m_Initializer instanceof OptionHandler) {
            String[] initializerOptions = ((OptionHandler) m_Initializer).getOptions();
            for (int i = 0; i < initializerOptions.length; i++) {
                options[current++] = initializerOptions[i];
            }
        }

        if (m_useMultipleMetrics) {
            options[current++] = "-U";
        }

        options[current++] = "-N";
        options[current++] = "" + getNumClusters();
        options[current++] = "-R";
        options[current++] = "" + getRandomSeed();

        // objective-function term weights
        options[current++] = "-l";
        options[current++] = "" + m_logTermWeight;
        options[current++] = "-r";
        options[current++] = "" + m_regularizerTermWeight;
        options[current++] = "-m";
        options[current++] = "" + m_MLweight;
        options[current++] = "-c";
        options[current++] = "" + m_CLweight;

        options[current++] = "-i";
        options[current++] = "" + m_maxIterations;
        options[current++] = "-B";
        options[current++] = "" + m_maxBlankIterations;

        // NOTE(review): when these are unset (null) the literal string "null"
        // is emitted here — confirm that downstream parsing tolerates it.
        options[current++] = "-O";
        options[current++] = "" + m_ClusterAssignmentsOutputFile;
        options[current++] = "-H";
        options[current++] = "" + m_ConstraintIncoherenceFile;
        options[current++] = "-V";
        options[current++] = "" + m_useTransitiveConstraints;

        // pad the remainder with empty strings (Weka convention)
        while (current < options.length) {
            options[current++] = "";
        }

        return options;
    }

    /**
     * Parses a given list of options.
     *
     * Class-valued options (-M metric, -L metric learner, -G regularizer,
     * -A assigner, -I initializer) take a class name optionally followed by
     * that class's own options.
     *
     * @param options the list of options as an array of strings
     * @exception Exception if an option is not supported
     *
     **/
    public void setOptions(String[] options) throws Exception {
        if (Utils.getFlag('X', options)) {
            System.out.println("Setting seedable to: false");
            setSeedable(false);
        }

        // NOTE(review): -T is parsed as an integer here, but getOptions emits
        // it as "Int"/"Ext" — the formats do not round-trip; confirm.
        String optionString = Utils.getOption('T', options);
        if (optionString.length() != 0) {
            setTrainable(new SelectedTag(Integer.parseInt(optionString), TAGS_TRAINING));
            System.out.println("Setting trainable to: " + Integer.parseInt(optionString));
        }

        optionString = Utils.getOption('M', options);
        if (optionString.length() != 0) {
            String[] metricSpec = Utils.splitOptions(optionString);
            String metricName = metricSpec[0];
            metricSpec[0] = ""; // remaining tokens are the metric's own options
            setMetric((LearnableMetric) Utils.forName(LearnableMetric.class, metricName, metricSpec));
            System.out.println("Setting metric to: " + metricName);
        }

        optionString = Utils.getOption('L', options);
        if (optionString.length() != 0) {
            String[] learnerSpec = Utils.splitOptions(optionString);
            String learnerName = learnerSpec[0];
            learnerSpec[0] = "";
            setMetricLearner(
                    (MPCKMeansMetricLearner) Utils.forName(MPCKMeansMetricLearner.class, learnerName, learnerSpec));
            System.out.println("Setting metricLearner to: " + m_metricLearner);
        }

        // NOTE(review): -G configures the regularizer on the current metric, so
        // it must appear after any -M option has been processed to take effect
        // on the new metric (options are read in fixed order here, not argv order).
        optionString = Utils.getOption('G', options);
        if (optionString.length() != 0) {
            String[] regularizerSpec = Utils.splitOptions(optionString);
            String regularizerName = regularizerSpec[0];
            regularizerSpec[0] = "";
            m_metric.setRegularizer(
                    (Regularizer) Utils.forName(Regularizer.class, regularizerName, regularizerSpec));
            System.out.println("Setting regularizer to: " + regularizerName);
        }

        optionString = Utils.getOption('A', options);
        if (optionString.length() != 0) {
            String[] assignerSpec = Utils.splitOptions(optionString);
            String assignerName = assignerSpec[0];
            assignerSpec[0] = "";
            setAssigner((MPCKMeansAssigner) Utils.forName(MPCKMeansAssigner.class, assignerName, assignerSpec));
            System.out.println("Setting assigner to: " + assignerName);
        }

        optionString = Utils.getOption('I', options);
        if (optionString.length() != 0) {
            String[] initializerSpec = Utils.splitOptions(optionString);
            String initializerName = initializerSpec[0];
            initializerSpec[0] = "";
            setInitializer((MPCKMeansInitializer) Utils.forName(MPCKMeansInitializer.class, initializerName,
                    initializerSpec));
            System.out.println("Setting initializer to: " + initializerName);
        }

        if (Utils.getFlag('U', options)) {
            setUseMultipleMetrics(true);
            System.out.println("Setting multiple metrics to: true");
        }

        optionString = Utils.getOption('N', options);
        if (optionString.length() != 0) {
            setNumClusters(Integer.parseInt(optionString));
            System.out.println("Setting numClusters to: " + m_NumClusters);
        }

        optionString = Utils.getOption('R', options);
        if (optionString.length() != 0) {
            setRandomSeed(Integer.parseInt(optionString));
            System.out.println("Setting randomSeed to: " + m_RandomSeed);
        }

        optionString = Utils.getOption('l', options);
        if (optionString.length() != 0) {
            setLogTermWeight(Double.parseDouble(optionString));
            System.out.println("Setting logTermWeight to: " + m_logTermWeight);
        }

        optionString = Utils.getOption('r', options);
        if (optionString.length() != 0) {
            setRegularizerTermWeight(Double.parseDouble(optionString));
            System.out.println("Setting regularizerTermWeight to: " + m_regularizerTermWeight);
        }

        optionString = Utils.getOption('m', options);
        if (optionString.length() != 0) {
            setMustLinkWeight(Double.parseDouble(optionString));
            System.out.println("Setting mustLinkWeight to: " + m_MLweight);
        }

        optionString = Utils.getOption('c', options);
        if (optionString.length() != 0) {
            setCannotLinkWeight(Double.parseDouble(optionString));
            System.out.println("Setting cannotLinkWeight to: " + m_CLweight);
        }

        optionString = Utils.getOption('i', options);
        if (optionString.length() != 0) {
            setMaxIterations(Integer.parseInt(optionString));
            System.out.println("Setting maxIterations to: " + m_maxIterations);
        }

        optionString = Utils.getOption('B', options);
        if (optionString.length() != 0) {
            setMaxBlankIterations(Integer.parseInt(optionString));
            System.out.println("Setting maxBlankIterations to: " + m_maxBlankIterations);
        }

        optionString = Utils.getOption('O', options);
        if (optionString.length() != 0) {
            setClusterAssignmentsOutputFile(optionString);
            System.out.println("Setting clusterAssignmentsOutputFile to: " + m_ClusterAssignmentsOutputFile);
        }

        optionString = Utils.getOption('H', options);
        if (optionString.length() != 0) {
            setConstraintIncoherenceFile(optionString);
            System.out.println("Setting m_ConstraintIncoherenceFile to: " + m_ConstraintIncoherenceFile);
        }

        // NOTE(review): -V is a bare flag that DISABLES transitive constraints,
        // while getOptions always emits "-V <value>" — confirm intended semantics.
        if (Utils.getFlag('V', options)) {
            setUseTransitiveConstraints(false);
            System.out.println("Setting useTransitiveConstraints to: false");
        }
    }

    /**
     * return a string describing this clusterer
     *
     * @return a description of the clusterer as a string
     */
    public String toString() {
        // No description is currently generated; an empty string is returned
        // (same observable result as the original StringBuffer round-trip).
        return "";
    }

    /**
     * set the verbosity level of the clusterer
     * @param verbose messages on (true) or off (false)
     */
    public void setVerbose(boolean verbose) {
        m_verbose = verbose;
    }

    /**
     * get the verbosity level of the clusterer
     * @return messages on (true) or off (false)
     */
    public boolean getVerbose() {
        return m_verbose;
    }

    /** Set whether transitive closure of constraints is used.
     * @param useTransitiveConstraints true to augment constraints with their transitive closure */
    public void setUseTransitiveConstraints(boolean useTransitiveConstraints) {
        m_useTransitiveConstraints = useTransitiveConstraints;
    }

    /** Return whether transitive closure of constraints is used.
     * @return true if transitive closure of constraints is used */
    public boolean getUseTransitiveConstraints() {
        return m_useTransitiveConstraints;
    }

    /**
     * Turn on/off the use of per-cluster metrics
     * @param useMultipleMetrics if true, individual metrics will be used for each cluster
     */
    public void setUseMultipleMetrics(boolean useMultipleMetrics) {
        m_useMultipleMetrics = useMultipleMetrics;
    }

    /**
     * See if individual per-cluster metrics are used
     * @return true if individual metrics are used for each cluster
     */
    public boolean getUseMultipleMetrics() {
        return m_useMultipleMetrics;
    }

    /**
     * Turn on/off the use of regularization of weights
     * @param regularize if true, weights will be regularized
     */
    public void setRegularize(boolean regularize) {
        m_regularize = regularize;
    }

    /**
     * See if weights are regularized
     * @return true if weights are regularized
     */
    public boolean getRegularize() {
        return m_regularize;
    }

    /**
     * Get the value of the weight assigned to the log term in the objective function
     * @return value of the weight assigned to the log term in the objective function
     */
    public double getLogTermWeight() {
        return m_logTermWeight;
    }

    /**
     * Set the value of the weight assigned to the log term in the objective function
     * @param logTermWeight weight assigned to the log term in the objective function
     */
    public void setLogTermWeight(double logTermWeight) {
        this.m_logTermWeight = logTermWeight;
    }

    /**
     * Get the value of the weight assigned to the regularizer term in the objective function
     * @return value of the weight assigned to the regularizer term in the objective function
     */
    public double getRegularizerTermWeight() {
        return m_regularizerTermWeight;
    }

    /**
     * Set the value of the weight assigned to the regularizer term in the objective function
     * @param regularizerTermWeight weight assigned to the regularizer term in the objective function
     */
    public void setRegularizerTermWeight(double regularizerTermWeight) {
        this.m_regularizerTermWeight = regularizerTermWeight;
    }

    /**
     * Train the clusterer's metric on the given instances.
     *
     * @param instances Instances to be used for training
     * @exception Exception if the metric is not a trainable LearnableMetric
     */
    public void trainClusterer(Instances instances) throws Exception {
        // Guard clause: both "not a LearnableMetric" and "not trainable" fail
        // with the same message, exactly as before.
        boolean trainable = (m_metric instanceof LearnableMetric)
                && ((LearnableMetric) m_metric).getTrainable();
        if (!trainable) {
            throw new Exception("Metric is not trainable");
        }
        ((LearnableMetric) m_metric).learnMetric(instances);
    }

    /** Read constraints from a file.
     *
     * Each line holds three whitespace-separated tokens: two instance indices
     * and a constraint value (negative = cannot-link, otherwise must-link).
     * Pairs are normalized so the smaller index comes first; duplicates are
     * dropped.  I/O or parse errors are logged and the pairs read so far are
     * returned (best-effort, as in the original).
     *
     * Fix: the original built a CANNOT_LINK pair for a must-link constraint
     * whenever first >= second; the reader was also never closed.
     *
     * @param fileName name of the constraints file
     * @return list of InstancePair constraints
     */
    public ArrayList readConstraints(String fileName) {
        ArrayList pairs = new ArrayList();

        try {
            BufferedReader reader = new BufferedReader(new FileReader(fileName));
            try {
                String s = null;

                while ((s = reader.readLine()) != null) {
                    StringTokenizer tokenizer = new StringTokenizer(s);
                    int first = 0, second = 0;
                    int i = 0;
                    while (tokenizer.hasMoreTokens()) {
                        String token = tokenizer.nextToken();
                        if (i == 0) {
                            first = Integer.parseInt(token);
                        } else if (i == 1) {
                            second = Integer.parseInt(token);
                        } else if (i == 2) {
                            int constraint = Integer.parseInt(token);
                            // Normalize so the smaller index always comes first.
                            int lo = Math.min(first, second);
                            int hi = Math.max(first, second);
                            // BUG FIX: the original created a CANNOT_LINK pair in the
                            // must-link branch when first >= second.
                            int linkType = (constraint < 0) ? InstancePair.CANNOT_LINK
                                    : InstancePair.MUST_LINK;
                            InstancePair pair = new InstancePair(lo, hi, linkType);
                            if (!pairs.contains(pair)) {
                                pairs.add(pair);
                            }
                        }
                        i++;
                    }
                }
            } finally {
                reader.close(); // don't leak the file handle
            }
        } catch (Exception e) {
            System.out.println("Problems reading from constraints file: " + e);
            e.printStackTrace();
        }

        return pairs;
    }

    /**
     * Main method for testing this class.
     *
     * @param args command-line options (see runFromCommandLine for the option list)
     */

    public static void main(String[] args) {
        //testCase();
        runFromCommandLine(args);
    }

    /** Command-line driver: reads a dataset (-D), an optional class index (-K)
     *  and optional constraints file (-C), clusters the data, prints the
     *  cluster assignments, and reports pairwise (Rand-index) accuracy when a
     *  class attribute is present.
     *
     *  NOTE(review): when -D is omitted, data stays null and the
     *  data.numAttributes() call below throws a NullPointerException, which is
     *  swallowed by the catch clause as "Option not specified" — confirm this
     *  is the intended failure mode.
     *
     *  @param args command-line options
     */
    public static void runFromCommandLine(String[] args) {
        MPCKMeans mpckmeans = new MPCKMeans();
        Instances data = null, clusterData = null;
        ArrayList labeledPairs = null;

        try {
            String optionString = Utils.getOption('D', args);
            if (optionString.length() != 0) {
                FileReader reader = new FileReader(optionString);
                data = new Instances(reader);
                System.out.println("Reading dataset: " + data.relationName());
            }

            // default to the last attribute as the class
            int classIndex = data.numAttributes() - 1;
            optionString = Utils.getOption('K', args);
            if (optionString.length() != 0) {
                classIndex = Integer.parseInt(optionString);
                if (classIndex >= 0) {
                    data.setClassIndex(classIndex); // starts with 0
                    // Remove the class labels before clustering
                    clusterData = new Instances(data);
                    mpckmeans.setNumClusters(clusterData.numClasses());
                    clusterData.deleteClassAttribute();
                    System.out.println("Setting classIndex: " + classIndex);
                } else {
                    // negative -K: cluster with no class attribute at all
                    clusterData = new Instances(data);
                }
            } else {
                data.setClassIndex(classIndex); // starts with 0
                // Remove the class labels before clustering
                clusterData = new Instances(data);
                mpckmeans.setNumClusters(clusterData.numClasses());
                clusterData.deleteClassAttribute();
                System.out.println("Setting classIndex: " + classIndex);
            }

            optionString = Utils.getOption('C', args);
            if (optionString.length() != 0) {
                labeledPairs = mpckmeans.readConstraints(optionString);
                System.out.println("Reading constraints from: " + optionString);
            } else {
                labeledPairs = new ArrayList(0);
            }

            mpckmeans.setTotalTrainWithLabels(data);
            mpckmeans.setOptions(args);
            System.out.println();
            mpckmeans.buildClusterer(labeledPairs, clusterData, data, mpckmeans.getNumClusters(),
                    data.numInstances());
            mpckmeans.printClusterAssignments();

            // Pairwise accuracy (Rand index): a pair counts as correct when it is
            // either together in both the clustering and the labeling, or apart in both.
            if (mpckmeans.m_TotalTrainWithLabels.classIndex() > -1) {
                double nCorrect = 0;
                for (int i = 0; i < mpckmeans.m_TotalTrainWithLabels.numInstances(); i++) {
                    for (int j = i + 1; j < mpckmeans.m_TotalTrainWithLabels.numInstances(); j++) {
                        int cluster_i = mpckmeans.m_ClusterAssignments[i];
                        int cluster_j = mpckmeans.m_ClusterAssignments[j];
                        double class_i = (mpckmeans.m_TotalTrainWithLabels.instance(i)).classValue();
                        double class_j = (mpckmeans.m_TotalTrainWithLabels.instance(j)).classValue();
                        if (cluster_i == cluster_j && class_i == class_j
                                || cluster_i != cluster_j && class_i != class_j) {
                            nCorrect++;
                        }
                    }
                }
                int numInstances = mpckmeans.m_TotalTrainWithLabels.numInstances();
                double RandIndex = 100 * nCorrect / (numInstances * (numInstances - 1) / 2);
                System.err.println("Acc\t" + RandIndex);
            }

            //      if (mpckmeans.getTotalTrainWithLabels().classIndex() >= 0) {
            //    SemiSupClustererEvaluation eval = new SemiSupClustererEvaluation(mpckmeans.m_TotalTrainWithLabels,
            //                             mpckmeans.m_TotalTrainWithLabels.numClasses(),
            //                             mpckmeans.m_TotalTrainWithLabels.numClasses());
            //    eval.evaluateModel(mpckmeans, mpckmeans.m_TotalTrainWithLabels, mpckmeans.m_Instances);
            //    eval.mutualInformation();
            //    eval.pairwiseFMeasure();
            //      }
        } catch (Exception e) {
            System.out.println("Option not specified");
            e.printStackTrace();
        }
    }

    public static void testCase() {
        try {
            String dataset = new String("lowd");
            //String dataset = new String("highd");
            if (dataset.equals("lowd")) {
                //////// Low-D data

                //   String datafile = "/u/ml/data/bio/arffFromPhylo/ecoli_K12-100.arff";
                //   String datafile = "/u/sugato/weka/data/digits-0.1-389.arff";
                String datafile = "/u/sugato/weka/data/iris.arff";
                int numPairs = 200, num = 0;

                // set up the data
                FileReader reader = new FileReader(datafile);
                Instances data = new Instances(reader);

                // Make the last attribute be the class 
                int classIndex = data.numAttributes() - 1;
                data.setClassIndex(classIndex); // starts with 0
                System.out.println("ClassIndex is: " + classIndex);

                // NOTE(review): this is the tail of a larger dataset-dispatch method whose
                // header (and the declarations of `data`, `numPairs`, `dataset`) precedes
                // this chunk; presumably a main/driver routine — confirm against full file.

                // Remove the class labels before clustering
                Instances clusterData = new Instances(data);
                clusterData.deleteClassAttribute();

                // create the pairs
                ArrayList labeledPair = InstancePair.getPairs(data, numPairs);

                System.out.println("Finished initializing constraint matrix");

                MPCKMeans mpckmeans = new MPCKMeans();
                mpckmeans.setUseMultipleMetrics(false);
                System.out.println("\nClustering the data using MPCKmeans...\n");

                // Metric choice for this branch: weighted Euclidean with its
                // gradient-style weight learner. The commented alternatives below
                // (dot product, KL, Bar-Hillel, Xing, Mahalanobis) are experiment
                // variants kept for reference.
                WeightedEuclidean metric = new WeightedEuclidean();
                WEuclideanLearner metricLearner = new WEuclideanLearner();

                //     LearnableMetric metric = new WeightedDotP();
                //     MPCKMeansMetricLearner metricLearner = new DotPGDLearner();

                //     KL metric = new KL();
                //     KLGDLearner metricLearner = new KLGDLearner();
                //   ((KL)metric).setUseIDivergence(true);

                //   BarHillelMetric metric = new BarHillelMetric();
                //   BarHillelMetricMatlab metric = new BarHillelMetricMatlab();
                //     XingMetric metric = new XingMetric();
                //   WeightedMahalanobis metric = new WeightedMahalanobis(); 

                mpckmeans.setMetric(metric);
                mpckmeans.setMetricLearner(metricLearner);
                mpckmeans.setVerbose(false);
                mpckmeans.setRegularize(false);
                mpckmeans.setTrainable(new SelectedTag(TRAINING_INTERNAL, TAGS_TRAINING));
                mpckmeans.setSeedable(true);
                // Cluster the unlabeled copy; `data` (with labels) is kept alongside
                // for seeding/evaluation. Number of clusters = number of classes.
                mpckmeans.buildClusterer(labeledPair, clusterData, data, data.numClasses(), data.numInstances());
                mpckmeans.getIndexClusters();
                mpckmeans.printIndexClusters();

                // Evaluate the clustering against the held-back class labels
                // (mutual information + pairwise F-measure/precision/recall).
                SemiSupClustererEvaluation eval = new SemiSupClustererEvaluation(mpckmeans.m_TotalTrainWithLabels,
                        mpckmeans.m_TotalTrainWithLabels.numClasses(),
                        mpckmeans.m_TotalTrainWithLabels.numClasses());
                eval.evaluateModel(mpckmeans, mpckmeans.m_TotalTrainWithLabels, mpckmeans.m_Instances);
                System.out.println("MI=" + eval.mutualInformation());
                System.out.print("FM=" + eval.pairwiseFMeasure());
                System.out.print("\tP=" + eval.pairwisePrecision());
                System.out.print("\tR=" + eval.pairwiseRecall());
            } else if (dataset.equals("highd")) {
                //////// Newsgroup data
                String datafile = "/u/ml/users/sugato/groupcode/weka335/data/arffFromCCS/sanitized/different-1000_sanitized.arff";
                //String datafile = "/u/ml/users/sugato/groupcode/weka335/data/20newsgroups/small-newsgroup_fromCCS.arff";
                //String datafile = "/u/ml/users/sugato/groupcode/weka335/data/20newsgroups/same-100_fromCCS.arff";

                // set up the data
                // NOTE(review): FileReader uses the platform default charset and is
                // never closed; a try-with-resources with an explicit charset would
                // be safer — left unchanged here.
                FileReader reader = new FileReader(datafile);
                Instances data = new Instances(reader);

                // Make the last attribute be the class 
                int classIndex = data.numAttributes() - 1;
                data.setClassIndex(classIndex); // starts with 0
                System.out.println("ClassIndex is: " + classIndex);

                // Remove the class labels before clustering
                Instances clusterData = new Instances(data);
                clusterData.deleteClassAttribute();

                // create the pairs
                // NOTE(review): numPairs is hard-coded to 0, so the sampling loop
                // below never executes and this branch runs fully unconstrained —
                // confirm whether that is intentional. Also note that for large
                // numPairs the rejection-sampling loop has no termination bound.
                int numPairs = 0, num = 0;
                ArrayList labeledPair = new ArrayList(numPairs);
                Random rand = new Random(42);
                System.out.println("Initializing constraint matrix:");
                while (num < numPairs) {
                    // Draw a random (first, second) index pair with first < second;
                    // link type is derived from whether the true classes agree.
                    int i = (int) (data.numInstances() * rand.nextFloat());
                    int j = (int) (data.numInstances() * rand.nextFloat());
                    int first = (i < j) ? i : j;
                    int second = (i >= j) ? i : j;
                    int linkType = (data.instance(first).classValue() == data.instance(second).classValue())
                            ? InstancePair.MUST_LINK
                            : InstancePair.CANNOT_LINK;
                    InstancePair pair = new InstancePair(first, second, linkType);
                    // Reject self-pairs and duplicates; only distinct new pairs count.
                    if (first != second && !labeledPair.contains(pair)) {
                        labeledPair.add(pair);
                        //System.out.println(num + "th entry is: " + pair);
                        num++;
                    }
                }
                System.out.println("Finished initializing constraint matrix");

                MPCKMeans mpckmeans = new MPCKMeans();
                mpckmeans.setUseMultipleMetrics(false);
                System.out.println("\nClustering the highd data using MPCKmeans...\n");

                // High-dimensional (text) branch uses weighted dot product with a
                // gradient-descent learner instead of weighted Euclidean.
                LearnableMetric metric = new WeightedDotP();
                MPCKMeansMetricLearner metricLearner = new DotPGDLearner();

                //     KL metric = new KL();
                //     KLGDLearner metricLearner = new KLGDLearner();

                mpckmeans.setMetric(metric);
                mpckmeans.setMetricLearner(metricLearner);
                mpckmeans.setVerbose(false);
                mpckmeans.setRegularize(true);
                mpckmeans.setTrainable(new SelectedTag(TRAINING_INTERNAL, TAGS_TRAINING));
                mpckmeans.setSeedable(true);
                mpckmeans.buildClusterer(labeledPair, clusterData, data, data.numClasses(), data.numInstances());
                mpckmeans.getIndexClusters();

                SemiSupClustererEvaluation eval = new SemiSupClustererEvaluation(mpckmeans.m_TotalTrainWithLabels,
                        mpckmeans.m_TotalTrainWithLabels.numClasses(),
                        mpckmeans.m_TotalTrainWithLabels.numClasses());

                mpckmeans.getMetric().resetMetric(); // Vital: to reset m_attrWeights to 1 for proper normalization
                eval.evaluateModel(mpckmeans, mpckmeans.m_TotalTrainWithLabels, mpckmeans.m_Instances);
                System.out.println("MI=" + eval.mutualInformation());
                System.out.print("FM=" + eval.pairwiseFMeasure());
                System.out.print("\tP=" + eval.pairwisePrecision());
                System.out.print("\tR=" + eval.pairwiseRecall());
            }
        } catch (Exception e) {
            // NOTE(review): broad catch-and-print at the driver boundary; the
            // process continues (exits normally) after any failure.
            e.printStackTrace();
        }
    }
}