syncleus.dann.data.video.FernEnsembleClassifier.java Source code

Java tutorial

Introduction

Here is the source code for syncleus.dann.data.video.FernEnsembleClassifier.java

Source

/**
 * Copyright 2013 Dan Oprescu
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *     http://www.apache.org/licenses/LICENSE-2.0
 *     
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package syncleus.dann.data.video;

import java.util.List;
import java.util.Properties;
import org.opencv.core.Mat;
import org.opencv.core.Size;
import syncleus.dann.data.video.TLDUtil.Pair;
import syncleus.dann.data.video.TLDUtil.RNG;

/**
 * An ensemble of random ferns used to score image patches.
 *
 * <p>Each {@link Fern} turns a patch into an integer hash code by running a fixed set of
 * pixel-pair comparisons, and keeps, per hash code, the fraction of training patches with
 * that code that were positive. The ensemble score of a patch is the mean of those
 * per-fern fractions (see {@link #averagePosterior(int[])}).
 *
 * <p>Typical call order visible from this class: construct, {@link #init(Size[], RNG)},
 * then {@link #getAllFernsHashCodes(Mat, int)} / {@link #trainF(List, int)} /
 * {@link #averagePosterior(int[])}.
 */
public class FernEnsembleClassifier {
    /** Classifier settings; stays {@code null} when the no-arg constructor is used. */
    ParamsClassifiers params;

    /** The ferns of the ensemble; allocated by {@link #init(Size[], RNG)}. */
    private Fern[] ferns;

    /**
     * Creates a classifier without parameters. {@code params} has to be assigned before
     * {@link #init(Size[], RNG)} is called.
     */
    public FernEnsembleClassifier() {
    }

    /**
     * Creates a classifier whose parameters are built from the given properties.
     *
     * @param props source handed to the {@code ParamsClassifiers} constructor
     */
    public FernEnsembleClassifier(Properties props) {
        params = new ParamsClassifiers(props);
    }

    /**
     * Allocates {@code params.numFerns} ferns, each with {@code params.numFeaturesPerFern}
     * random features. Any previously trained ferns are discarded.
     *
     * @param scales patch sizes; every fern gets one feature set per entry
     * @param rng    source of the random feature positions
     * @throws IllegalStateException if no parameters have been set
     */
    public void init(Size[] scales, RNG rng) {
        if (params == null) {
            throw new IllegalStateException(
                    "FernEnsembleClassifier has no parameters; construct it with Properties"
                            + " or set params before calling init()");
        }

        ferns = new Fern[params.numFerns];
        for (int i = 0; i < ferns.length; i++) {
            ferns[i] = new Fern(params.numFeaturesPerFern, scales, rng);
        }
    }

    /**
     * Raises the positive threshold ({@code params.pos_thr_fern}) so that it is at least as
     * high as the largest average posterior found among the given examples. The threshold
     * is never lowered and the ferns themselves are not modified.
     *
     * @param nFernsTest hash codes (one per fern) of test examples; they are expected to be
     *                   NEGATIVE examples, the {@code second} flag is not checked
     */
    void evaluateThreshold(final List<Pair<int[], Boolean>> nFernsTest) {
        for (Pair<int[], Boolean> fern : nFernsTest) {
            // here we know/hope that fern.second is always FALSE, as they are all NEGATIVE examples
            final double averagePosterior = averagePosterior(fern.first);
            if (averagePosterior > params.pos_thr_fern) {
                params.pos_thr_fern = averagePosterior;
            }
        }
    }

    /**
     * Trains the ensemble on labelled examples, going over the whole list {@code resample}
     * times.
     *
     * <p>An example only updates the ferns while the ensemble still scores it on the wrong
     * side of its threshold: positives while their score is {@code <= params.pos_thr_fern},
     * negatives while their score is {@code >= params.neg_thr_fern}. This keeps the
     * probabilities from being pushed beyond the given limits, to give other hash codes a
     * chance.
     *
     * @param ferns    examples as (hash codes per fern, is-positive) pairs; note that this
     *                 parameter shadows the {@code ferns} field inside the method
     * @param resample number of passes over the examples
     */
    void trainF(final List<Pair<int[], Boolean>> ferns, int resample) {
        for (int pass = 0; pass < resample; pass++) {
            for (Pair<int[], Boolean> example : ferns) {
                // The score is re-evaluated for every example as earlier updates change it.
                final double confidence = averagePosterior(example.first);
                if (example.second) { // positive example
                    if (confidence <= params.pos_thr_fern) {
                        updatePosteriors(example.first, true);
                    }
                } else if (confidence >= params.neg_thr_fern) {
                    updatePosteriors(example.first, false);
                }
            }
        }
    }

    /**
     * Records one training example in every fern.
     *
     * @param fernsHashCodes hash code of the example for each fern, indexed like the ferns
     * @param positive       {@code true} to count it as positive, {@code false} as negative
     */
    private void updatePosteriors(final int[] fernsHashCodes, boolean positive) {
        assert (params.numFerns == fernsHashCodes.length);

        for (int fern = 0; fern < fernsHashCodes.length; fern++) {
            ferns[fern].addCountUpdatePosteriors(fernsHashCodes[fern], positive);
        }
    }

    /**
     * Computes the ensemble confidence for one example.
     *
     * @param fernsHashCodes hash code of the example for each fern, indexed like the ferns
     * @return the mean of the per-fern posterior probabilities for those hash codes
     *         ({@code NaN} for an empty array)
     */
    double averagePosterior(final int[] fernsHashCodes) {
        assert (params.numFerns == fernsHashCodes.length);

        double sum = 0;
        for (int fern = 0; fern < fernsHashCodes.length; fern++) {
            sum += ferns[fern].posteriorProbabilities[fernsHashCodes[fern]];
        }
        return sum / fernsHashCodes.length;
    }

    /**
     * Computes the hash code of a patch for every fern.
     *
     * <p>Each hash code lies in {@code [0, 2^featuresPerFern)}, as every feature of a fern
     * contributes one bit (the code is shifted left once per feature).
     *
     * @param patch    image patch; its bytes are read through {@code TLDUtil.getByteArray}
     * @param scaleIdx index of the scale (as passed to {@link #init(Size[], RNG)}) whose
     *                 feature positions are to be used
     * @return one hash code per fern, indexed like the ferns
     */
    int[] getAllFernsHashCodes(final Mat patch, int scaleIdx) {
        final int[] result = new int[ferns.length];
        // Extract the pixel data once; all ferns read from the same array.
        final byte[] imageData = TLDUtil.getByteArray(patch);
        final int cols = patch.cols();
        for (int fern = 0; fern < ferns.length; fern++) {
            result[fern] = ferns[fern].calculateHashCode(scaleIdx, imageData, cols);
        }

        return result;
    }

    /**
     * A single fern: a fixed list of pixel-pair features plus, per possible hash code,
     * the counts of positive/negative training patches and the resulting positive ratio.
     */
    static class Fern {
        /**
         * Largest supported number of features: hash codes are non-negative {@code int}s
         * used as array indices, so {@code 2^featuresPerFern} has to fit in an {@code int}.
         */
        private static final int MAX_FEATURES_PER_FERN = 30;

        private final Feature[][] features; // per scaleIdx
        // per HASHCODE
        final double[] posteriorProbabilities; // the probability that it's our image
        final long[] nCounter; // the number of NEGATIVE patches
        final long[] pCounter; // the number of POSITIVE patches

        /**
         * Creates a fern with random features and empty statistics.
         *
         * @param featuresPerFern number of pixel comparisons, i.e. bits of the hash code
         * @param scales          patch sizes; the same relative positions are scaled to each
         * @param rng             source of the relative positions (four floats per feature)
         * @throws IllegalArgumentException if {@code featuresPerFern} is outside
         *                                  {@code [0, 30]}
         */
        Fern(int featuresPerFern, Size[] scales, RNG rng) {
            if (featuresPerFern < 0 || featuresPerFern > MAX_FEATURES_PER_FERN) {
                throw new IllegalArgumentException("featuresPerFern must be in [0, "
                        + MAX_FEATURES_PER_FERN + "] but was " + featuresPerFern);
            }

            // 1. Define random features
            features = new Feature[scales.length][featuresPerFern];
            for (int i = 0; i < featuresPerFern; i++) {
                // Relative coordinates are drawn once per feature so that the feature sits
                // at the same relative position in every scale.
                final float x1f = rng.nextFloat();
                final float y1f = rng.nextFloat();
                final float x2f = rng.nextFloat();
                final float y2f = rng.nextFloat();
                for (int s = 0; s < scales.length; s++) {
                    final int x1 = (int) (x1f * scales[s].width);
                    final int y1 = (int) (y1f * scales[s].height);
                    final int x2 = (int) (x2f * scales[s].width);
                    final int y2 = (int) (y2f * scales[s].height);
                    features[s][i] = new Feature(x1, y1, x2, y2);
                }
            }

            // 2. Initialise Posteriors: one slot per possible hash code (2^featuresPerFern)
            final int hashCodeCount = 1 << featuresPerFern;
            posteriorProbabilities = new double[hashCodeCount];
            pCounter = new long[hashCodeCount];
            nCounter = new long[hashCodeCount];
        }

        /**
         * Counts one more patch for the given hash code and refreshes its posterior, which
         * is {@code positives / (positives + negatives)} for that hash code.
         *
         * @param fernHashCode hash code of the patch, as returned by
         *                     {@link #calculateHashCode(int, byte[], int)}
         * @param positive     whether the patch is a positive example
         */
        void addCountUpdatePosteriors(int fernHashCode, boolean positive) {
            if (positive) {
                pCounter[fernHashCode]++;
            } else {
                nCounter[fernHashCode]++;
            }

            // The denominator is at least 1 here because one counter was just incremented.
            posteriorProbabilities[fernHashCode] = ((double) pCounter[fernHashCode])
                    / (pCounter[fernHashCode] + nCounter[fernHashCode]);
        }

        /**
         * Builds the hash code of a patch: one bit per feature, first feature in the most
         * significant position.
         *
         * @param scaleIdx  index of the scale whose feature positions are used
         * @param imageData pixel bytes of the patch
         * @param cols      number of columns of the patch (row stride in {@code imageData})
         * @return the hash code, in {@code [0, 2^featuresPerFern)}
         */
        int calculateHashCode(int scaleIdx, byte[] imageData, int cols) {
            int fernHashCode = 0;
            for (Feature feature : features[scaleIdx]) {
                // compare returns 0 / 1, which becomes the new lowest bit
                fernHashCode = (fernHashCode << 1) + feature.compare(imageData, cols);
            }

            return fernHashCode;
        }
    }

    /**
     * A Feature is a pixel Comparison, between 2 points.
     */
    private static class Feature {
        private final int x1, y1, x2, y2;

        public Feature(int x1, int y1, int x2, int y2) {
            this.x1 = x1;
            this.y1 = y1;
            this.x2 = x2;
            this.y2 = y2;
        }

        /**
         * Simply compares the brightness between the 2 points defining this Feature.
         * Assumes channels = 1 (hence only multiplying with cols) and that every byte is
         * one 8-bit pixel intensity, which is read as an unsigned value (0..255).
         *
         * @param patch pixel bytes, row after row
         * @param cols  number of columns per row
         * @return 1 if the first point is brighter than the second, otherwise 0; also 0
         *         (with a message on standard output) if a point lies beyond the end of
         *         {@code patch}
         */
        public int compare(final byte[] patch, final int cols) {
            final int pos1 = y1 * cols + x1;
            final int pos2 = y2 * cols + x2;
            if (pos1 >= patch.length || pos2 >= patch.length) {
                System.out.println("Bad patch of size: " + patch.length + " cols: " + cols + " to compare Feature: "
                        + this.toString());
                return 0;
            }

            // Java bytes are signed: without the mask a pixel of brightness 128..255 would
            // be negative and compare as darker than any pixel of brightness 0..127.
            final int brightness1 = patch[pos1] & 0xFF;
            final int brightness2 = patch[pos2] & 0xFF;
            return brightness1 > brightness2 ? 1 : 0;
        }

        @Override
        public String toString() {
            return x1 + ", " + y1 + ", " + x2 + ", " + y2;
        }
    }

    /**
     * @return the configured number of ferns ({@code params.numFerns})
     */
    int getNumFerns() {
        return params.numFerns;
    }

    /**
     * @return the current positive threshold ({@code params.pos_thr_fern}), which
     *         {@link #evaluateThreshold(List)} may have raised
     */
    double getFernPosThreshold() {
        return params.pos_thr_fern;
    }

    // TODO keep the positive examples used by learning and expose them for display,
    //      e.g. stacked vertically into a single CV_8U Mat.
}