qupath.opencv.classify.OpenCvClassifier.java Source code

Java tutorial

Introduction

Here is the source code for qupath.opencv.classify.OpenCvClassifier.java

Source

/*-
 * #%L
 * This file is part of QuPath.
 * %%
 * Copyright (C) 2014 - 2016 The Queen's University of Belfast, Northern Ireland
 * Contact: IP Management (ipmanagement@qub.ac.uk)
 * %%
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public
 * License along with this program.  If not, see
 * <http://www.gnu.org/licenses/gpl-3.0.html>.
 * #L%
 */

package qupath.opencv.classify;

import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import qupath.lib.analysis.stats.RunningStatistics;
import qupath.lib.classifiers.Normalization;
import qupath.lib.classifiers.PathObjectClassifier;
import qupath.lib.measurements.MeasurementList;
import qupath.lib.objects.PathObject;
import qupath.lib.objects.classes.PathClass;
import qupath.lib.objects.classes.PathClassFactory;
import qupath.lib.plugins.parameters.ParameterList;
import qupath.lib.plugins.parameters.Parameterizable;

import org.opencv.core.CvException;
import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.ml.Ml;
import org.opencv.ml.StatModel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Abstract base class for OpenCV classifiers.
 * 
 * Note: We cannot directly serialize an OpenCV classifier, so instead the training data is serialized and the classifier
 * rebuilt as required.  This means that potentially if a classifier is reloaded with a different version of the OpenCV library,
 * if the training algorithm has changed then there may be a different result.
 * 
 * @author Pete Bankhead
 *
 * @param <T> the concrete OpenCV {@link StatModel} type created by {@link #createClassifier()}
 */
public abstract class OpenCvClassifier<T extends StatModel> implements PathObjectClassifier, Externalizable {

    private static final long serialVersionUID = -7974734731360344083L;

    final private static Logger logger = LoggerFactory.getLogger(OpenCvClassifier.class);

    // Timestamp of the most recent (re)training; reported by getLastModifiedTimestamp()
    private long timestamp = System.currentTimeMillis();
    // Normalization strategy applied to training data (and to samples at prediction time)
    private Normalization normalization = Normalization.NONE;
    // Sorted list of classes; classifier responses are indices into this list
    List<PathClass> pathClasses;
    // Per-measurement normalization: value' = (value + normOffset[i]) * normScale[i]; null when no normalization
    private double[] normScale;
    private double[] normOffset;
    // The OpenCV model itself cannot be serialized - it is rebuilt from the stored training arrays on deserialization
    transient T classifier;

    // Names of the measurements used as features, in column order
    List<String> measurements = new ArrayList<>();
    // We can't serialize directly, so instead save all training data so classifier can be rebuilt as required
    float[] arrayTraining = null; // Row-major training data, one row per sample, one column per measurement
    int[] arrayResponses = null; // Array of 'responses', i.e. indices to pathClasses list

    protected OpenCvClassifier() {
    }

    /**
     * Protected method used to indicate whether any options for the classifier have been changed.
     * If this false, then updateClassifier may choose not to retrain a classifier fully if it already has a classifier
     * trained on identical data.
     * 
     * By default this always returns false (assuming that no externally-accessible parameters are involved).
     * 
     * A conservative subclass that enables options to be set may always return 'true' to force retraining in all instances.
     * 
     * A less conservative subclass that enables options to be set should check all options to see if they have changed since
     * the last time the classifier was trained, and return true or false accordingly.
     * 
     * @return true if classifier options changed since the last training, false otherwise
     */
    protected boolean classifierOptionsChanged() {
        return false;
    }

    /**
     * (Re)train the classifier from a map of class &rarr; training objects.
     * 
     * Builds the training arrays (applying normalization if requested), then trains a new
     * classifier - unless the training data, measurements, classes, normalization and options
     * are all unchanged, in which case the existing classifier is kept.
     * 
     * @param map training objects grouped by their classification
     * @param measurements names of the measurements to use as features
     * @param normalization normalization to apply to the features (may be null, treated as NONE)
     * @return true if a new classifier was trained, false if the existing one could be reused
     */
    @Override
    public boolean updateClassifier(final Map<PathClass, List<PathObject>> map, final List<String> measurements,
            Normalization normalization) {

        // There is a chance we don't need to retrain... to find out, cache the most important current variables
        boolean maybeSameClassifier = isValid() && this.normalization == normalization
                && !classifierOptionsChanged() && this.measurements.equals(measurements)
                && pathClasses.size() == map.size() && map.keySet().containsAll(pathClasses);

        float[] arrayTrainingPrevious = arrayTraining;
        int[] arrayResponsesPrevious = arrayResponses;

        // Responses are indices into this sorted list, so ordering must be deterministic
        pathClasses = new ArrayList<>(map.keySet());
        Collections.sort(pathClasses);

        // Count the total number of training samples
        int n = 0;
        for (Map.Entry<PathClass, List<PathObject>> entry : map.entrySet()) {
            n += entry.getValue().size();
        }

        // Compute running statistics for normalization
        HashMap<String, RunningStatistics> statsMap = new LinkedHashMap<>();
        for (String m : measurements)
            statsMap.put(m, new RunningStatistics());

        this.measurements.clear();
        this.measurements.addAll(measurements);
        int nMeasurements = measurements.size();
        arrayTraining = new float[n * nMeasurements];
        arrayResponses = new int[n];

        int row = 0;
        int nnan = 0;
        for (PathClass pathClass : pathClasses) {
            List<PathObject> list = map.get(pathClass);
            int classIndex = pathClasses.indexOf(pathClass);
            for (int i = 0; i < list.size(); i++) {
                MeasurementList measurementList = list.get(i).getMeasurementList();
                int col = 0;
                for (String m : measurements) {
                    double value = measurementList.getMeasurementValue(m);
                    // NaNs are stored as-is in the training array, but excluded from the normalization statistics
                    if (Double.isNaN(value))
                        nnan++;
                    else
                        statsMap.get(m).addValue(value);
                    arrayTraining[row * nMeasurements + col] = (float) value;
                    col++;
                }
                arrayResponses[row] = classIndex;
                row++;
            }
        }

        // Normalise, if required
        if (normalization != null && normalization != Normalization.NONE) {
            logger.debug("Training classifier with normalization: {}", normalization);
            int numMeasurements = measurements.size();
            normOffset = new double[numMeasurements];
            normScale = new double[numMeasurements];
            for (int i = 0; i < numMeasurements; i++) {
                RunningStatistics stats = statsMap.get(measurements.get(i));
                if (normalization == Normalization.MEAN_VARIANCE) {
                    normOffset[i] = -stats.getMean();
                    // Guard against zero standard deviation (constant feature): leave scale at 0
                    if (stats.getStdDev() > 0)
                        normScale[i] = 1.0 / stats.getStdDev();
                } else if (normalization == Normalization.MIN_MAX) {
                    normOffset[i] = -stats.getMin();
                    if (stats.getRange() > 0)
                        normScale[i] = 1.0 / (stats.getMax() - stats.getMin());
                    else
                        normScale[i] = 1.0;
                }
            }

            // Apply normalisation
            for (int i = 0; i < arrayTraining.length; i++) {
                int k = i % numMeasurements;
                arrayTraining[i] = (float) ((arrayTraining[i] + normOffset[k]) * normScale[k]);
            }
            this.normalization = normalization;

        } else {
            logger.debug("Training classifier without normalization");
            normScale = null;
            normOffset = null;
            this.normalization = Normalization.NONE;
        }

        // Record that we have NaNs
        if (nnan > 0)
            logger.debug("Number of NaNs in training set: {}", nnan);

        // Having got this far, check to see whether we really do need to retrain
        if (maybeSameClassifier) {
            if (Arrays.equals(arrayTrainingPrevious, arrayTraining)
                    && Arrays.equals(arrayResponsesPrevious, arrayResponses)) {
                logger.info("Classifier already trained with the same samples - existing classifier will be used");
                return false;
            }
        }

        createAndTrainClassifier();

        timestamp = System.currentTimeMillis();
        // Note: this.measurements was already refreshed above, before training

        return true;
    }

    /**
     * Build the OpenCV training matrices from the stored arrays and train a freshly-created classifier.
     * 
     * If OpenCV throws a CvException during training (observed intermittently with tree-based models),
     * a single retry is attempted with one sample removed, since this has been seen to allow recovery.
     */
    protected void createAndTrainClassifier() {

        // Create the required Mats
        int nMeasurements = measurements.size();

        Mat matTraining = new Mat(arrayTraining.length / nMeasurements, nMeasurements, CvType.CV_32FC1);
        matTraining.put(0, 0, arrayTraining);
        Mat matResponses = new Mat(arrayResponses.length, 1, CvType.CV_32SC1);
        matResponses.put(0, 0, arrayResponses);

        logger.info("Training size: {}", matTraining.size());
        logger.info("Responses size: {}", matResponses.size());

        // Create & train the classifier
        try {
            classifier = createClassifier();
            classifier.train(matTraining, Ml.ROW_SAMPLE, matResponses);
        } catch (CvException e) {
            // For reasons I haven't yet discerned, sometimes OpenCV throws an exception with the following message:
            // OpenCV Error: Assertion failed ((int)_sleft.size() < n && (int)_sright.size() < n) in calcDir, file /tmp/opencv320150620-1681-1u5iwhh/opencv-3.0.0/modules/ml/src/tree.cpp, line 1190
            // With one sample fewer, it can often recover... so attempt that, rather than failing miserably...
            logger.info("Will attempt retraining classifier with one sample fewer...");
            matTraining = matTraining.rowRange(0, matTraining.rows() - 1);
            matResponses = matResponses.rowRange(0, matResponses.rows() - 1);
            classifier = createClassifier();
            classifier.train(matTraining, Ml.ROW_SAMPLE, matResponses);
        }

        matTraining.release();
        matResponses.release();

        logger.info("Classifier trained with {} samples", arrayResponses.length);
    }

    /**
     * @return a defensive copy of the measurement names required as features
     */
    @Override
    public List<String> getRequiredMeasurements() {
        return new ArrayList<>(measurements);
    }

    /**
     * @return a defensive copy of the classes this classifier can assign;
     *         empty if the classifier has not yet been trained
     */
    @Override
    public Collection<PathClass> getPathClasses() {
        // Guard against NPE before any training has taken place
        if (pathClasses == null)
            return Collections.emptyList();
        return new ArrayList<>(pathClasses);
    }

    @Override
    public boolean isValid() {
        return classifier != null && classifier.isTrained();
    }

    /**
     * Classify a collection of objects, setting the PathClass of each.
     * 
     * Measurements are normalized with the same offsets/scales used during training (if any).
     * If prediction fails for an object its classification is reset to null.
     * 
     * @param pathObjects the objects to classify
     * @return the number of objects processed
     */
    @Override
    public int classifyPathObjects(Collection<PathObject> pathObjects) {

        int counter = 0;
        float[] array = new float[measurements.size()];
        Mat samples = new Mat(1, array.length, CvType.CV_32FC1);

        Mat results = new Mat();

        for (PathObject pathObject : pathObjects) {
            MeasurementList measurementList = pathObject.getMeasurementList();
            int idx = 0;
            for (String m : measurements) {
                double value = measurementList.getMeasurementValue(m);

                // Apply the same normalization used at training time
                if (normScale != null && normOffset != null)
                    value = (value + normOffset[idx]) * normScale[idx];

                array[idx] = (float) value;
                idx++;
            }

            samples.put(0, 0, array);

            try {
                setPredictedClass(classifier, pathClasses, samples, results, pathObject);
            } catch (Exception e) {
                pathObject.setPathClass(null);
                // Only pay for the (potentially expensive) matrix dump if TRACE is actually enabled
                if (logger.isTraceEnabled())
                    logger.trace("Error with samples: {}", samples.dump());
            }
            counter++;
        }

        samples.release();
        results.release();

        return counter;
    }

    /**
     * Default prediction method.  Makes no attempt to populate results matrix or to provide probabilities.
     * (Results matrix only given as a parameter in case it is needed)
     * 
     * Subclasses may choose to override this method if they can do a better prediction, e.g. providing probabilities as well.
     * 
     * Upon returning, it is assumed that the PathClass of the PathObject will be correct, but it is not assumed that the results matrix will
     * have been updated.
     * 
     * @param classifier the trained model used for prediction
     * @param pathClasses classes indexed by the model's (integer) prediction
     * @param samples single-row Mat of (normalized) feature values
     * @param results optional output Mat; unused by this default implementation
     * @param pathObject the object whose classification will be set
     */
    protected void setPredictedClass(final T classifier, final List<PathClass> pathClasses, final Mat samples,
            final Mat results, final PathObject pathObject) {
        float prediction = classifier.predict(samples);
        PathClass pathClass = pathClasses.get((int) prediction);
        pathObject.setPathClass(pathClass);
    }

    /**
     * Create a new classifier, of whichever type the subclass desires.
     * 
     * It can be assumed that this is the classifier that will be used - without modifications - until createClassifier is called again.
     * 
     * In other words, it is permissible to cache values within createClassifier() (e.g. TermCriteria) that might
     * be import during prediction.
     * 
     * @return a new, untrained classifier instance
     */
    protected abstract T createClassifier();

    /**
     * @return a human-readable, multi-line summary of the classifier: name, classes,
     *         normalization, parameters (if Parameterizable) and required measurements
     */
    @Override
    public String getDescription() {

        if (classifier == null)
            return "No classifier set!";

        StringBuilder sb = new StringBuilder();
        String mainString = getName() + (!isValid() ? " (not trained)" : "");
        sb.append("Classifier:\t").append(mainString).append("\n\n");
        sb.append("Classes:\t[");
        // Append the closing bracket unconditionally, so the output is well-formed
        // even when there are no classes (the original only closed it inside the loop)
        Iterator<PathClass> iterClasses = getPathClasses().iterator();
        while (iterClasses.hasNext()) {
            sb.append(iterClasses.next());
            if (iterClasses.hasNext())
                sb.append(", ");
        }
        sb.append("]\n\n");
        sb.append("Normalization:\t").append(normalization).append("\n\n");

        if (this instanceof Parameterizable) {
            ParameterList params = ((Parameterizable) this).getParameterList();
            String paramString = ParameterList.getParameterListJSON(params, "\n  ");
            sb.append("Main parameters:\n  ").append(paramString);
            sb.append("\n\n");
        }

        List<String> measurements = getRequiredMeasurements();
        sb.append("Required measurements (").append(measurements.size()).append("):\n");
        Iterator<String> iter = getRequiredMeasurements().iterator();
        while (iter.hasNext()) {
            sb.append("    ");
            sb.append(iter.next());
            sb.append("\n");
        }

        return sb.toString();
    }

    @Override
    public long getLastModifiedTimestamp() {
        return timestamp;
    }

    /**
     * Serialize the training state (not the OpenCV model itself).
     * 
     * Wire format, version 2: version (long), timestamp (long), pathClasses, normScale,
     * normOffset, measurements, arrayTraining, arrayResponses, normalization name (String).
     */
    @Override
    public void writeExternal(ObjectOutput out) throws IOException {
        out.writeLong(2); // Version
        out.writeLong(timestamp);
        out.writeObject(pathClasses);
        out.writeObject(normScale);
        out.writeObject(normOffset);
        out.writeObject(measurements);
        out.writeObject(arrayTraining);
        out.writeObject(arrayResponses);
        out.writeObject(normalization.toString());
    }

    /**
     * Restore the training state written by {@link #writeExternal(ObjectOutput)} (versions 1 and 2),
     * then retrain the transient classifier from the stored training data.
     */
    @SuppressWarnings("unchecked") // Casts match the types written in writeExternal
    @Override
    public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {

        long version = in.readLong();
        if (version < 1 || version > 2)
            throw new IOException("Unsupported version!");

        timestamp = in.readLong();
        pathClasses = (List<PathClass>) in.readObject();
        // Ensure we have correct, single entries
        if (pathClasses != null) {
            for (int i = 0; i < pathClasses.size(); i++) {
                pathClasses.set(i, PathClassFactory.getSingletonPathClass(pathClasses.get(i)));
            }
        }

        normScale = (double[]) in.readObject();
        normOffset = (double[]) in.readObject();
        measurements = (List<String>) in.readObject();
        arrayTraining = (float[]) in.readObject();
        arrayResponses = (int[]) in.readObject();
        if (version == 2) {
            // Normalization was stored by name; match it against the known enum values
            String method = (String) in.readObject();
            for (Normalization n : Normalization.values()) {
                if (n.toString().equals(method)) {
                    normalization = n;
                    break;
                }
            }
        }

        // The classifier is transient - rebuild it from the deserialized training data
        if (arrayTraining != null && arrayResponses != null) {
            createAndTrainClassifier();
        }

    }

}