org.apache.mahout.classifier.rbm.training.RBMClassifierTrainingJob.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.mahout.classifier.rbm.training.RBMClassifierTrainingJob.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.mahout.classifier.rbm.training;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.classifier.rbm.RBMClassifier;
import org.apache.mahout.classifier.rbm.model.LabeledSimpleRBM;
import org.apache.mahout.classifier.rbm.model.RBMModel;
import org.apache.mahout.classifier.rbm.model.SimpleRBM;
import org.apache.mahout.classifier.rbm.network.DeepBoltzmannMachine;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixWritable;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

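/**
 * Hadoop job that trains an {@link RBMClassifier}: it optionally initializes the biases of the
 * first layer from the training data, runs greedy pre-training of the RBMs and fine-tunes the
 * network with backprop. Each training step runs either locally or as a map/reduce job.
 */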
public class RBMClassifierTrainingJob extends AbstractJob {

    /** Name of the temp directory the map/reduce training jobs write their weight updates to. */
    public static final String WEIGHT_UPDATES = "weightupdates";

    /** The logger of this job. */
    private static final Logger logger = LoggerFactory.getLogger(RBMClassifierTrainingJob.class);

    /**
     * The last update per RBM (one slot per RBM), needed for use of the momentum.
     * While a slot is null the greedy training passes a momentum of 0 for that RBM.
     */
    Matrix[] lastUpdate;

    /** The rbm classifier that is trained; created or materialized in run. */
    RBMClassifier rbmCl = null;

    /** The number of iterations (epochs) over all batches, from the "epochs" option. */
    int epochs;

    /** The learningrate at the beginning of training; it is decreased after every batch. */
    double learningrate;

    /** The momentum handed to the model update (0 is used instead where no previous update exists). */
    double momentum;

    /** If true, progress and error measurements are logged during training. */
    boolean monitor;

    /** If true, the biases of layer 0 are initialized from the training data. */
    boolean initbiases;

    /** If true, greedy pre-training is run. */
    boolean greedy;

    /** If true, fine-tuning is run after the pre-training. */
    boolean finetuning;

    /** The batches to train on: the input file itself, or every entry of the input directory. */
    Path[] batches = null;

    /** The total count of labels, from the "labelcount" option. */
    int labelcount;

    /** The number of gibbs sampling steps handed to the contrastive divergence trainer. */
    int nrGibbsSampling;

    /** The number of the rbm to pre-train; a value &lt; 0 means all rbms are trained. */
    int rbmNrtoTrain;

    /**
     * Command line entry point: hands a fresh configuration, a new job instance and the
     * arguments to Hadoop's {@link ToolRunner}.
     *
     * @param args the command line arguments of the job
     * @throws Exception if the tool runner or the job fails
     */
    public static void main(String[] args) throws Exception {
        Configuration configuration = new Configuration();
        RBMClassifierTrainingJob trainingJob = new RBMClassifierTrainingJob();
        ToolRunner.run(configuration, trainingJob, args);
    }

    /**
     * Parses the command line, creates a new model or materializes an existing one from the
     * output path, and runs the selected training phases: bias initialization, greedy
     * pre-training and fine-tuning.
     *
     * @param args the command line arguments
     * @return 0 on success, -1 if the arguments are invalid, no training data is found or a
     *         training step fails
     * @throws Exception if reading the data, training or writing the model fails
     * @see org.apache.hadoop.util.Tool#run(java.lang.String[])
     */
    @Override
    public int run(String[] args) throws Exception {
        addInputOption();
        addOutputOption();
        addOption("epochs", "e", "number of training epochs through the trainingset", true);
        addOption("structure", "s", "comma-separated list of layer sizes", false);
        addOption("labelcount", "lc", "total count of labels existent in the training set", true);
        addOption("learningrate", "lr", "learning rate at the beginning of training", "0.005");
        addOption("momentum", "m", "momentum of learning at the beginning", "0.5");
        addOption("rbmnr", "nr", "rbm to train, < 0 means train all", "-1");
        addOption("nrgibbs", "gn", "number of gibbs sampling used in contrastive divergence", "5");
        addOption(new DefaultOptionBuilder().withLongName(DefaultOptionCreator.MAPREDUCE_METHOD).withRequired(false)
                .withDescription("Run training with map/reduce").withShortName("mr").create());
        addOption(new DefaultOptionBuilder().withLongName("nogreedy").withRequired(false)
                .withDescription("Don't run greedy pre training").withShortName("ng").create());
        addOption(new DefaultOptionBuilder().withLongName("nofinetuning").withRequired(false)
                .withDescription("Don't run fine tuning at the end").withShortName("nf").create());
        addOption(new DefaultOptionBuilder().withLongName("nobiases").withRequired(false)
                .withDescription("Don't initialize biases").withShortName("nb").create());
        addOption(new DefaultOptionBuilder().withLongName("monitor").withRequired(false)
                .withDescription("If present, errors can be monitored in cosole").withShortName("mon").create());
        addOption(DefaultOptionCreator.overwriteOption().create());

        Map<String, String> parsedArgs = parseArguments(args);
        if (parsedArgs == null) {
            return -1;
        }

        Path input = getInputPath();
        Path output = getOutputPath();
        FileSystem fs = FileSystem.get(output.toUri(), getConf());
        labelcount = Integer.parseInt(getOption("labelcount"));

        boolean local = !hasOption("mapreduce");
        monitor = hasOption("monitor");
        initbiases = !hasOption("nobiases");
        finetuning = !hasOption("nofinetuning");
        greedy = !hasOption("nogreedy");

        //a single file is one batch, a directory contributes one batch per entry
        if (fs.isFile(input)) {
            batches = new Path[] { input };
        } else {
            FileStatus[] stati = fs.listStatus(input);
            batches = new Path[stati.length];
            for (int i = 0; i < stati.length; i++) {
                batches[i] = stati[i].getPath();
            }
        }

        epochs = Integer.parseInt(getOption("epochs"));
        learningrate = Double.parseDouble(getOption("learningrate"));
        momentum = Double.parseDouble(getOption("momentum"));
        rbmNrtoTrain = Integer.parseInt(getOption("rbmnr"));
        nrGibbsSampling = Integer.parseInt(getOption("nrgibbs"));

        //a new model is built if overwriting was requested or the output holds no model yet
        boolean initialize = hasOption(DefaultOptionCreator.OVERWRITE_OPTION) || !fs.exists(output)
                || fs.listStatus(output).length <= 0;

        if (initialize) {
            String structure = getOption("structure");
            if (structure == null || structure.isEmpty()) {
                logger.error("No model found in output, so the structure option is required");
                return -1;
            }

            String[] layers = structure.split(",");
            if (layers.length < 2) {
                logger.error("The structure option needs at least two layer sizes");
                return -1;
            }

            int[] actualLayerSizes = new int[layers.length];
            for (int i = 0; i < layers.length; i++) {
                //tolerate blanks around the commas
                actualLayerSizes[i] = Integer.parseInt(layers[i].trim());
            }

            rbmCl = new RBMClassifier(labelcount, actualLayerSizes);
            logger.info("New model initialized!");
        } else {
            rbmCl = RBMClassifier.materialize(output, getConf());
            logger.info("Model found and materialized!");
        }

        HadoopUtil.setSerializations(getConf());
        lastUpdate = new Matrix[rbmCl.getDbm().getRbmCount()];

        try {
            if (initbiases && !initVisibleBiasesFromData()) {
                return -1;
            }

            //greedy pre training with gradually decreasing learningrates
            if (greedy && !runGreedyPretraining(local, initialize, output)) {
                return -1;
            }

            if (finetuning && !runFinetuning(local, output)) {
                return -1;
            }

            return 0;
        } finally {
            //also reached on every failure path: the pool threads are no daemon threads and
            //would otherwise keep the JVM alive
            if (executor != null) {
                executor.shutdownNow();
                executor = null;
            }
        }
    }

    /**
     * Sets the biases of layer 0 to the mean of all training vectors.
     *
     * @return false if the batches contain no record at all
     */
    private boolean initVisibleBiasesFromData() {
        Vector sum = null;
        int counter = 0;
        for (Path batch : batches) {
            for (Pair<IntWritable, VectorWritable> record : new SequenceFileIterable<IntWritable, VectorWritable>(
                    batch, getConf())) {
                Vector sample = record.getSecond().get();
                //plus() returns a new vector and leaves its receiver untouched, so keep the result
                sum = (sum == null) ? sample.clone() : sum.plus(sample);
                counter++;
            }
        }
        if (sum == null) {
            logger.info("No training data found!");
            return false;
        }

        rbmCl.getDbm().getLayer(0).setBiases(sum.divide(counter));
        logger.info("Biases initialized");
        return true;
    }

    /**
     * Runs the greedy pre-training for all rbms (rbmNrtoTrain &lt; 0) or for the selected one and
     * writes the model to the output afterwards.
     *
     * @param local true to train locally, false to train with map/reduce
     * @param initialize true if the model was newly created instead of materialized
     * @param output the path the model is written to
     * @return false if a training step failed or the selected rbm does not exist
     * @throws Exception if training or writing the model fails
     */
    private boolean runGreedyPretraining(boolean local, boolean initialize, Path output) throws Exception {
        int rbmCount = rbmCl.getDbm().getRbmCount();
        if (rbmNrtoTrain >= rbmCount) {
            logger.error("Cannot train rbm number " + rbmNrtoTrain + ", the model has only " + rbmCount
                    + " rbms");
            return false;
        }

        if (!local) {
            rbmCl.serialize(output, getConf());
        }

        if (rbmNrtoTrain < 0) {
            //train all rbms
            for (int rbmNr = 0; rbmNr < rbmCount; rbmNr++) {
                if (!pretrainSingleRbm(rbmNr, local, initialize, output, true)) {
                    return false;
                }
            }
        } else if (!pretrainSingleRbm(rbmNrtoTrain, local, initialize, output, false)) {
            //train just wanted rbm
            return false;
        }

        rbmCl.serialize(output, getConf());
        logger.info("Pretraining done and model written to output");
        return true;
    }

    /**
     * Pre-trains one rbm for all epochs, starting again at the initial learningrate.
     *
     * @param rbmNr the number of the rbm to train
     * @param local true to train locally, false to train with map/reduce
     * @param initialize true if the model was newly created instead of materialized
     * @param output the path the model is written to
     * @param trainingAll true if this call is part of training all rbms; then the model is written
     *        after every epoch and the progress message names the rbm
     * @return false if a training step failed
     * @throws Exception if training or writing the model fails
     */
    private boolean pretrainSingleRbm(int rbmNr, boolean local, boolean initialize, Path output,
            boolean trainingAll) throws Exception {
        //only rbms that have another rbm below and above them are rescaled
        boolean intermediate = rbmNr > 0 && rbmNr < rbmCl.getDbm().getRbmCount() - 1;

        //double weights if dbm was materialized, because it was halved after greedy pretraining
        if (!initialize && intermediate) {
            rescaleRbmWeights(rbmNr, 2);
        }

        double tempLearningrate = learningrate;
        double decay = learningrate / (epochs * batches.length + epochs);

        for (int j = 0; j < epochs; j++) {
            logger.info("Greedy training, epoch " + (j + 1) + "\nCurrent learningrate: " + tempLearningrate);
            for (int b = 0; b < batches.length; b++) {
                tempLearningrate -= decay;
                boolean trained = local ? trainGreedySeq(rbmNr, batches[b], j, tempLearningrate)
                        : trainGreedyMR(rbmNr, batches[b], j, tempLearningrate);
                if (!trained) {
                    return false;
                }
                //report roughly every 5% of the epoch
                if (monitor && (batches.length > 19) && (b + 1) % (batches.length / 20) == 0) {
                    logger.info(rbmNr + "-RBM: " + Math.round(((double) b + 1) / batches.length * 100.0)
                            + "% in epoch done!");
                }
            }

            long percentDone = Math.round(((double) j + 1) / epochs * 100);
            if (trainingAll) {
                logger.info(percentDone + "% of training on rbm number " + rbmNr + " is done!");
            } else {
                logger.info(percentDone + "% of training is done!");
            }

            if (monitor && batches.length > 0) {
                double error = rbmError(batches[0], rbmNr);
                logger.info("Average reconstruction error on batch " + batches[0].getName() + ": " + error);
            }

            if (trainingAll) {
                rbmCl.serialize(output, getConf());
            }
        }

        //weight normalization to avoid double counting
        if (intermediate) {
            rescaleRbmWeights(rbmNr, 0.5);
        }
        return true;
    }

    /**
     * Multiplies the weight matrix of the given rbm with a factor.
     *
     * @param rbmNr the number of the rbm
     * @param factor the factor the weights are multiplied with
     */
    private void rescaleRbmWeights(int rbmNr, double factor) {
        SimpleRBM rbm = (SimpleRBM) rbmCl.getDbm().getRBM(rbmNr);
        rbm.setWeightMatrix(rbm.getWeightMatrix().times(factor));
    }

    /**
     * Runs the fine-tuning for all epochs over all batches and writes the model to the output.
     *
     * @param local true to train locally, false to train with map/reduce
     * @param output the path the model is written to
     * @return false if a training step failed
     * @throws Exception if training or writing the model fails
     */
    private boolean runFinetuning(boolean local, Path output) throws Exception {
        DeepBoltzmannMachine multiLayerDbm = null;
        double tempLearningrate = learningrate;
        double decay = learningrate / (epochs * batches.length + epochs);

        //finetuning job
        for (int j = 0; j < epochs; j++) {
            for (int b = 0; b < batches.length; b++) {
                multiLayerDbm = rbmCl.initializeMultiLayerNN();
                logger.info("Finetuning on batch " + batches[b].getName() + "\nCurrent learningrate: "
                        + tempLearningrate);
                tempLearningrate -= decay;
                boolean trained = local ? finetuneSeq(batches[b], j, multiLayerDbm, tempLearningrate)
                        : fintuneMR(batches[b], j, tempLearningrate);
                if (!trained) {
                    return false;
                }
                logger.info("Finetuning: " + Math.round(((double) b + 1) / batches.length * 100.0)
                        + "% in epoch done!");
            }
            logger.info(Math.round(((double) j + 1) / epochs * 100) + "% of training is done!");

            //without any batch there is neither a network nor data to measure the error on
            if (monitor && batches.length > 0) {
                double error = feedForwardError(multiLayerDbm, batches[0]);
                logger.info("Average discriminative error on batch " + batches[0].getName() + ": " + error);
            }
        }

        //final serialization
        rbmCl.serialize(output, getConf());
        logger.info("RBM finetuning done and model written to output");
        return true;
    }

    /**
     * Callable used by the local fine-tuning: calculates the backprop weight updates of one
     * training example.
     */
    class BackpropTrainingThread implements Callable<Matrix[]> {

        /** The network the updates are calculated on. */
        private DeepBoltzmannMachine dbm;

        /** The training example. */
        private Vector input;

        /** The label vector belonging to the example. */
        private Vector label;

        /** The trainer calculating the weight updates. */
        private BackPropTrainer trainer;

        /**
         * Creates a task for one training example.
         *
         * @param dbm the network to calculate the updates on
         * @param label the label vector of the example
         * @param input the training example
         * @param trainer the trainer calculating the updates
         */
        public BackpropTrainingThread(DeepBoltzmannMachine dbm, Vector label, Vector input,
                BackPropTrainer trainer) {
            this.trainer = trainer;
            this.input = input;
            this.label = label;
            this.dbm = dbm;
        }

        /**
         * Calculates the weight updates and merges the last two result matrices into one.
         *
         * @return one update matrix per entry, the last one being the merged matrix
         * @see java.util.concurrent.Callable#call()
         */
        @Override
        public Matrix[] call() throws Exception {
            Matrix[] rawUpdates = trainer.calculateWeightUpdates(dbm, input, label);
            Matrix[] merged = new Matrix[dbm.getRbmCount() - 1];
            int mergeIndex = rawUpdates.length - 2;

            //everything below the last pair is taken over unchanged
            for (int idx = 0; idx < mergeIndex; idx++) {
                merged[idx] = rawUpdates[idx];
            }

            //the last two matrices refer to just one labeled rbm, which was split to two for the
            //training, so they are put together: the first one on top, the second one transposed below
            if (mergeIndex >= 0) {
                Matrix upperPart = rawUpdates[mergeIndex];
                Matrix lowerPart = rawUpdates[mergeIndex + 1];
                int upperRows = upperPart.rowSize();
                int totalRows = upperRows + lowerPart.columnSize();
                int columns = upperPart.columnSize();

                merged[mergeIndex] = new DenseMatrix(totalRows, columns);
                for (int row = 0; row < totalRows; row++) {
                    for (int col = 0; col < columns; col++) {
                        double value = (row < upperRows) ? upperPart.get(row, col)
                                : lowerPart.get(col, row - upperRows);
                        merged[mergeIndex].set(row, col, value);
                    }
                }
            }

            return merged;
        }

    }

    /**
     * The backprop training tasks of the local fine-tuning; the list is created on first use in
     * finetuneSeq and its tasks are kept and refilled with new examples afterwards.
     */
    List<BackpropTrainingThread> backpropTrainingTasks;

    /**
     * Finetune locally: calculates the backprop weight updates of every example in the batch,
     * sums them up per matrix and applies the sum to the classifier.
     * The examples are processed in rounds of up to 20 tasks, which are run by the executor.
     *
     * @param batch the batch
     * @param iteration the iteration (epoch); in iteration 0 no momentum is used
     * @param multiLayerDbm the multilayer dbm the tasks work on (each task gets its own clone)
     * @param learningrate the learningrate
     * @return true, if successful
     * @throws InterruptedException the interrupted exception
     * @throws ExecutionException the execution exception
     */
    private boolean finetuneSeq(Path batch, int iteration, DeepBoltzmannMachine multiLayerDbm, double learningrate)
            throws InterruptedException, ExecutionException {
        Vector label = new DenseVector(labelcount);
        Map<Integer, Matrix> updates = new HashMap<Integer, Matrix>();
        int batchsize = 0;
        //maximum number of threads that are used, I think 20 is ok
        int threadCount = 20;
        //initialize the tasks, which are run by the executor
        if (backpropTrainingTasks == null) {
            backpropTrainingTasks = new ArrayList<BackpropTrainingThread>();
        }
        //initialize the executor if not already done
        if (executor == null) {
            executor = Executors.newFixedThreadPool(threadCount);
        }

        for (Pair<IntWritable, VectorWritable> record : new SequenceFileIterable<IntWritable, VectorWritable>(batch,
                getConf())) {
            label.assign(0);
            label.set(record.getFirst().get(), 1);

            BackPropTrainer trainer = new BackPropTrainer(learningrate);

            //prepare the tasks: the slot is chosen by position in the round, not by list size, so
            //tasks left over from a shorter earlier batch are refilled instead of run with stale data
            int slot = batchsize % threadCount;
            if (slot >= backpropTrainingTasks.size()) {
                backpropTrainingTasks.add(new BackpropTrainingThread(multiLayerDbm.clone(), label.clone(),
                        record.getSecond().get(), trainer));
            } else {
                BackpropTrainingThread task = backpropTrainingTasks.get(slot);
                task.input = record.getSecond().get();
                task.label = label.clone();
                //the learningrate changes from batch to batch, so the trainer must be replaced too
                task.trainer = trainer;
                //the network is fixed during one batch, so one clone per slot and call is enough
                if (batchsize < threadCount) {
                    task.dbm = multiLayerDbm.clone();
                }
            }

            //run the tasks and save results
            if (slot == threadCount - 1) {
                sumBackpropUpdates(executor.invokeAll(backpropTrainingTasks), updates);
            }

            batchsize++;
        }

        //run remaining tasks: exactly the first (batchsize % threadCount) slots hold new examples
        int remaining = batchsize % threadCount;
        if (remaining != 0) {
            sumBackpropUpdates(executor.invokeAll(backpropTrainingTasks.subList(0, remaining)), updates);
        }

        updateRbmCl(batchsize, (iteration == 0) ? 0 : momentum, updates);

        return true;
    }

    /**
     * Adds the results of finished backprop tasks to the sums collected so far.
     *
     * @param futureUpdates the finished tasks
     * @param updates the sums per matrix index, changed in place
     * @throws InterruptedException the interrupted exception
     * @throws ExecutionException if a task failed
     */
    private static void sumBackpropUpdates(List<Future<Matrix[]>> futureUpdates, Map<Integer, Matrix> updates)
            throws InterruptedException, ExecutionException {
        for (Future<Matrix[]> futureUpdate : futureUpdates) {
            Matrix[] weightUpdates = futureUpdate.get();
            for (int j = 0; j < weightUpdates.length; j++) {
                Matrix sum = updates.get(j);
                updates.put(j, (sum == null) ? weightUpdates[j] : weightUpdates[j].plus(sum));
            }
        }
    }

    /**
     * Runs one fine-tuning step over a batch as map/reduce job and applies the resulting weight
     * updates to the model in the output path.
     *
     * @param batch the batch
     * @param iteration the iteration (epoch); in iteration 0 a momentum of 0 is passed on
     * @param learningrate the learningrate
     * @return true, if the job completed successfully
     * @throws IOException Signals that an I/O exception has occurred.
     * @throws InterruptedException the interrupted exception
     * @throws ClassNotFoundException the class not found exception
     */
    private boolean fintuneMR(Path batch, int iteration, double learningrate)
            throws IOException, InterruptedException, ClassNotFoundException {
        //clear the updates of the previous job and hand the current model to the job
        Path updatesPath = getTempPath(WEIGHT_UPDATES);
        HadoopUtil.delete(getConf(), updatesPath);
        HadoopUtil.cacheFiles(getOutputPath(), getConf());

        Job backpropJob = prepareJob(batch, updatesPath, SequenceFileInputFormat.class,
                DBMBackPropTrainingMapper.class, IntWritable.class, MatrixWritable.class,
                DBMBackPropTrainingReducer.class, IntWritable.class, MatrixWritable.class,
                SequenceFileOutputFormat.class);

        Configuration jobConf = backpropJob.getConfiguration();
        jobConf.set("labelcount", String.valueOf(labelcount));
        jobConf.set("learningrate", String.valueOf(learningrate));

        //the reducer is used as combiner as well
        backpropJob.setCombinerClass(DBMBackPropTrainingReducer.class);

        boolean completed = backpropJob.waitForCompletion(true);
        if (completed) {
            //the size of the batch is taken from the job's counter
            long batchsize = backpropJob.getCounters().findCounter(DBMBackPropTrainingMapper.BATCHES.SIZE)
                    .getValue();
            double usedMomentum = (iteration == 0) ? 0 : momentum;
            changeAndSaveModel(getOutputPath(), batchsize, usedMomentum);
        }
        return completed;
    }

    /**
     * Callable used by the local greedy pre-training: calculates the weight updates of one rbm
     * for one training example.
     */
    class GreedyTrainingThread implements Callable<Matrix> {

        /** The dbm containing the rbm to train. */
        private DeepBoltzmannMachine dbm;

        /** The training example. */
        private Vector input;

        /** The label vector belonging to the example. */
        private Vector label;

        /** The trainer calculating the weight updates. */
        private CDTrainer trainer;

        /** The number of the rbm to train. */
        int rbmNr;

        /**
         * Creates a task for one training example.
         *
         * @param dbm the dbm containing the rbm to train
         * @param label the label vector of the example
         * @param input the training example
         * @param trainer the trainer calculating the updates
         * @param rbmNr the number of the rbm to train
         */
        public GreedyTrainingThread(DeepBoltzmannMachine dbm, Vector label, Vector input, CDTrainer trainer,
                int rbmNr) {
            this.rbmNr = rbmNr;
            this.trainer = trainer;
            this.input = input;
            this.label = label;
            this.dbm = dbm;
        }

        /**
         * Passes the example up through the rbms below the trained one and calculates the
         * weight updates of the trained rbm.
         *
         * @return the weight updates of the rbm
         * @see java.util.concurrent.Callable#call()
         */
        @Override
        public Matrix call() throws Exception {
            dbm.getRBM(0).getVisibleLayer().setActivations(input);

            int below = 0;
            while (below < rbmNr) {
                RBMModel lowerRbm = dbm.getRBM(below);
                //double the bottom up connection for initialization
                lowerRbm.exciteHiddenLayer(2, false);
                if (below + 1 == rbmNr) {
                    //probabilities as activation for the data the rbm should train on
                    lowerRbm.getHiddenLayer().setProbabilitiesAsActivation();
                } else {
                    lowerRbm.getHiddenLayer().updateNeurons();
                }
                below++;
            }

            boolean topRbm = rbmNr == dbm.getRbmCount() - 1;
            if (!topRbm) {
                return trainer.calculateWeightUpdates((SimpleRBM) dbm.getRBM(rbmNr), false, rbmNr == 0);
            }

            //the last rbm is the labeled one, so its softmax layer gets the label
            LabeledSimpleRBM labeledRbm = (LabeledSimpleRBM) dbm.getRBM(rbmNr);
            labeledRbm.getSoftmaxLayer().setActivations(label);
            return trainer.calculateWeightUpdates(labeledRbm, true, false);
        }

    }

    /**
     * The executor running the local training tasks; it is created on first use by the local
     * training methods and shut down by run.
     */
    private ExecutorService executor;

    /**
     * The greedy training tasks of the local pre-training; the list is created on first use in
     * trainGreedySeq and its tasks are kept and refilled with new examples afterwards.
     */
    List<GreedyTrainingThread> greedyTrainingTasks;

    /**
     * Train greedy seq: calculates the weight updates of one rbm for every example in the batch
     * locally, sums them up and applies the sum to the classifier.
     * The examples are processed in rounds of up to 20 tasks, which are run by the executor.
     *
     * @param rbmNr the rbm nr
     * @param batch the batch
     * @param iteration the iteration (not used by this method)
     * @param learningrate the learningrate
     * @return true, if successful
     * @throws InterruptedException the interrupted exception
     * @throws ExecutionException the execution exception
     */
    private boolean trainGreedySeq(int rbmNr, Path batch, int iteration, double learningrate)
            throws InterruptedException, ExecutionException {
        int batchsize = 0;
        DeepBoltzmannMachine dbm = rbmCl.getDbm();
        Vector label = new DenseVector(labelcount);
        Matrix updates = null;
        //number of threads running the tasks
        int threadCount = 20;
        if (executor == null) {
            executor = Executors.newFixedThreadPool(threadCount);
        }
        if (greedyTrainingTasks == null) {
            greedyTrainingTasks = new ArrayList<RBMClassifierTrainingJob.GreedyTrainingThread>();
        }

        for (Pair<IntWritable, VectorWritable> record : new SequenceFileIterable<IntWritable, VectorWritable>(batch,
                getConf())) {
            CDTrainer trainer = new CDTrainer(learningrate, nrGibbsSampling);
            label.assign(0);
            label.set(record.getFirst().get(), 1);

            //prepare the tasks: the slot is chosen by position in the round, not by list size, so
            //tasks left over from a shorter earlier batch are refilled instead of run with stale data
            int slot = batchsize % threadCount;
            if (slot >= greedyTrainingTasks.size()) {
                greedyTrainingTasks.add(new GreedyTrainingThread(dbm.clone(), label.clone(),
                        record.getSecond().get(), trainer, rbmNr));
            } else {
                GreedyTrainingThread task = greedyTrainingTasks.get(slot);
                task.input = record.getSecond().get();
                task.label = label.clone();
                //the learningrate changes from batch to batch, so the trainer must be replaced too
                task.trainer = trainer;
                //the model is fixed during one batch, so one clone per slot and call is enough
                if (batchsize < threadCount) {
                    task.dbm = dbm.clone();
                    task.rbmNr = rbmNr;
                }
            }

            //run tasks
            if (slot == threadCount - 1) {
                updates = sumGreedyUpdates(executor.invokeAll(greedyTrainingTasks), updates);
            }

            batchsize++;
        }

        //run remaining tasks: exactly the first (batchsize % threadCount) slots hold new examples
        int remaining = batchsize % threadCount;
        if (remaining != 0) {
            updates = sumGreedyUpdates(executor.invokeAll(greedyTrainingTasks.subList(0, remaining)), updates);
        }

        //an empty batch yields no update matrix; there is nothing to apply then
        if (updates == null) {
            logger.warn("Batch " + batch.getName() + " contains no records, no weight update applied");
            return true;
        }

        Map<Integer, Matrix> updateMap = new HashMap<Integer, Matrix>();
        updateMap.put(rbmNr, updates);
        updateRbmCl(batchsize, (lastUpdate[rbmNr] == null) ? 0 : momentum, updateMap);

        return true;
    }

    /**
     * Adds the results of finished greedy training tasks to the sum collected so far.
     *
     * @param futureUpdates the finished tasks
     * @param sum the sum so far, null if there is none yet
     * @return the new sum, null only if there was no sum and no task
     * @throws InterruptedException the interrupted exception
     * @throws ExecutionException if a task failed
     */
    private static Matrix sumGreedyUpdates(List<Future<Matrix>> futureUpdates, Matrix sum)
            throws InterruptedException, ExecutionException {
        Matrix total = sum;
        for (Future<Matrix> futureUpdate : futureUpdates) {
            Matrix update = futureUpdate.get();
            total = (total == null) ? update : total.plus(update);
        }
        return total;
    }

    /**
     * Runs one greedy pre-training step of an rbm over a batch as map/reduce job and applies the
     * resulting weight updates to the model in the output path.
     *
     * @param rbmNr the rbm nr
     * @param batch the batch
     * @param iteration the iteration (not used by this method)
     * @param learningrate the learningrate
     * @return true, if the job completed successfully
     * @throws IOException Signals that an I/O exception has occurred.
     * @throws InterruptedException the interrupted exception
     * @throws ClassNotFoundException the class not found exception
     */
    private boolean trainGreedyMR(int rbmNr, Path batch, int iteration, double learningrate)
            throws IOException, InterruptedException, ClassNotFoundException {
        //clear the updates of the previous job and hand the current model to the job
        Path updatesPath = getTempPath(WEIGHT_UPDATES);
        HadoopUtil.delete(getConf(), updatesPath);
        HadoopUtil.cacheFiles(getOutputPath(), getConf());

        Job pretrainingJob = prepareJob(batch, updatesPath, SequenceFileInputFormat.class,
                RBMGreedyPreTrainingMapper.class, IntWritable.class, MatrixWritable.class,
                RBMGreedyPreTrainingReducer.class, IntWritable.class, MatrixWritable.class,
                SequenceFileOutputFormat.class);

        Configuration jobConf = pretrainingJob.getConfiguration();
        jobConf.set("rbmNr", String.valueOf(rbmNr));
        jobConf.set("labelcount", String.valueOf(labelcount));
        jobConf.set("learningrate", String.valueOf(learningrate));
        jobConf.set("nrGibbsSampling", String.valueOf(nrGibbsSampling));

        //the reducer is used as combiner as well
        pretrainingJob.setCombinerClass(RBMGreedyPreTrainingReducer.class);

        boolean completed = pretrainingJob.waitForCompletion(true);
        if (completed) {
            //the size of the batch is taken from the job's counter
            long batchsize = pretrainingJob.getCounters().findCounter(RBMGreedyPreTrainingMapper.BATCH.SIZE)
                    .getValue();
            //no momentum as long as there is no last update for this rbm
            double usedMomentum = (lastUpdate[rbmNr] == null) ? 0 : momentum;
            changeAndSaveModel(getOutputPath(), batchsize, usedMomentum);
        }
        return completed;
    }

    /**
     * Calculates the average error of the classifier on a batch: for every record the
     * classifier is run with 1 as second argument and 1 minus the score of the record's label
     * is taken as error.
     *
     * @param batch the batch
     * @return the average error over all records of the batch
     */
    @SuppressWarnings("unused")
    private double classifierError(Path batch) {
        double errorSum = 0;
        int samples = 0;
        SequenceFileIterable<IntWritable, VectorWritable> records =
                new SequenceFileIterable<IntWritable, VectorWritable>(batch, getConf());
        for (Pair<IntWritable, VectorWritable> sample : records) {
            Vector scores = rbmCl.classify(sample.getSecond().get(), 1);
            int trueLabel = sample.getFirst().get();
            errorSum += 1 - scores.get(trueLabel);
            samples++;
        }
        return errorSum / samples;
    }

    /**
     * Calculates the average error of the feed forward net on a batch: every record is passed
     * up through all rbms and 1 minus the activation of the topmost hidden layer at the
     * record's label is taken as error.
     *
     * @param feedForwardNet the feed forward net; the activations of its layers are overwritten
     * @param batch the batch
     * @return the average error over all records of the batch
     */
    private double feedForwardError(DeepBoltzmannMachine feedForwardNet, Path batch) {
        double errorSum = 0;
        int samples = 0;
        RBMModel topRbm = null;
        SequenceFileIterable<IntWritable, VectorWritable> records =
                new SequenceFileIterable<IntWritable, VectorWritable>(batch, getConf());
        for (Pair<IntWritable, VectorWritable> sample : records) {
            feedForwardNet.getRBM(0).getVisibleLayer().setActivations(sample.getSecond().get());

            //pass the example up, using the probabilities as activations in every hidden layer
            int level = 0;
            while (level < feedForwardNet.getRbmCount()) {
                topRbm = feedForwardNet.getRBM(level);
                topRbm.exciteHiddenLayer(1, false);
                topRbm.getHiddenLayer().setProbabilitiesAsActivation();
                level++;
            }

            Vector output = topRbm.getHiddenLayer().getActivations();
            errorSum += 1 - output.get(sample.getFirst().get());
            samples++;
        }
        return errorSum / samples;
    }

    /**
     * Computes the mean reconstruction error of one RBM of the classifier's DBM over a batch.
     *
     * <p>For every record the vector is set as the activations of the visible layer of RBM 0 and
     * pushed upwards through all RBMs below {@code rbmNr}: each of them excites its hidden layer
     * with {@code exciteHiddenLayer(2, false)}; the RBM directly below {@code rbmNr} then takes
     * the probabilities as activations, all others call {@code updateNeurons()}. If the RBM at
     * {@code rbmNr} is a {@link LabeledSimpleRBM}, its softmax layer is set to a one-hot vector
     * with a 1 at the record's label. Finally the reconstruction error of the RBM at
     * {@code rbmNr} is added to the sum.
     *
     * <p>The label vector is sized from the softmax layer of the last RBM of the DBM, which is
     * cast to {@link LabeledSimpleRBM} unconditionally.
     *
     * @param batch path of the sequence file holding (label, vector) records
     * @param rbmNr index of the RBM whose reconstruction error is measured
     * @return the error sum divided by the number of records; {@code NaN} if the batch
     *         contains no records (0.0 divided by 0)
     */
    private double rbmError(Path batch, int rbmNr) {
        DeepBoltzmannMachine dbm = rbmCl.getDbm();
        LabeledSimpleRBM topRbm = (LabeledSimpleRBM) dbm.getRBM(dbm.getRbmCount() - 1);
        Vector label = new DenseVector(topRbm.getSoftmaxLayer().getNeuronCount());

        double errorSum = 0;
        int recordCount = 0;

        SequenceFileIterable<IntWritable, VectorWritable> records =
                new SequenceFileIterable<IntWritable, VectorWritable>(batch, getConf());
        for (Pair<IntWritable, VectorWritable> record : records) {
            dbm.getRBM(0).getVisibleLayer().setActivations(record.getSecond().get());

            for (int layer = 0; layer < rbmNr; layer++) {
                RBMModel lowerRbm = dbm.getRBM(layer);
                //double the bottom up connection for initialization
                lowerRbm.exciteHiddenLayer(2, false);
                if (layer < rbmNr - 1) {
                    lowerRbm.getHiddenLayer().updateNeurons();
                } else {
                    // layer directly below the evaluated RBM
                    lowerRbm.getHiddenLayer().setProbabilitiesAsActivation();
                }
            }

            RBMModel evaluatedRbm = dbm.getRBM(rbmNr);
            if (evaluatedRbm instanceof LabeledSimpleRBM) {
                // one-hot encoding of the record's label
                label.assign(0);
                label.set(record.getFirst().get(), 1);
                ((LabeledSimpleRBM) evaluatedRbm).getSoftmaxLayer().setActivations(label);
            }
            errorSum += evaluatedRbm.getReconstructionError();
            recordCount++;
        }

        return errorSum / recordCount;
    }

    /**
     * Collects the weight updates written to the temporary {@code WEIGHT_UPDATES} path, applies
     * them to the classifier and serializes the changed classifier.
     *
     * <p>All part files under the temporary path are read. Matrices that share the same key are
     * added up, so the map handed to {@code updateRbmCl} holds one summed matrix per key.
     *
     * @param output the path the classifier is serialized to
     * @param batchsize the batch size passed on to {@code updateRbmCl}
     * @param momentum the momentum passed on to {@code updateRbmCl}
     * @throws IOException Signals that an I/O exception has occurred.
     */
    private void changeAndSaveModel(Path output, long batchsize, double momentum) throws IOException {
        Map<Integer, Matrix> summedUpdates = new HashMap<Integer, Matrix>();

        SequenceFileDirIterable<IntWritable, MatrixWritable> partialUpdates =
                new SequenceFileDirIterable<IntWritable, MatrixWritable>(getTempPath(WEIGHT_UPDATES),
                        PathType.LIST, PathFilters.partFilter(), getConf());

        for (Pair<IntWritable, MatrixWritable> record : partialUpdates) {
            int rbmIndex = record.getFirst().get();
            Matrix partial = record.getSecond().get();
            if (updates(summedUpdates, rbmIndex)) {
                // key seen before: fold the previous sum into this partial update
                partial = partial.plus(summedUpdates.get(rbmIndex));
            }
            summedUpdates.put(rbmIndex, partial);
        }

        updateRbmCl(batchsize, momentum, summedUpdates);

        //serialization for mappers to have actual version of the dbm
        rbmCl.serialize(output, getConf());
    }

    /** Tells whether a summed update has already been stored for the given key. */
    private static boolean updates(Map<Integer, Matrix> summedUpdates, int rbmIndex) {
        return summedUpdates.containsKey(rbmIndex);
    }

    /**
     * Applies the given summed weight updates to the RBMs of the classifier.
     *
     * <p>For every entry the summed matrix is divided by {@code batchsize}. If {@code momentum}
     * is positive and a previous update exists for that RBM, the result is blended as
     * {@code averaged * (1 - momentum) + lastUpdate * momentum}; an RBM without a previous
     * update gets the plain averaged update, which is the same rule the caller applies when it
     * passes a momentum of 0 for a missing {@code lastUpdate} entry. The final update is written
     * back into {@code updates} and remembered in {@code lastUpdate}.
     *
     * <p>RBMs below the last index are treated as {@link SimpleRBM} and the update is added to
     * their weight matrix. The RBM at the last index is treated as {@link LabeledSimpleRBM}: the
     * first rows of the update (as many as its weight matrix has) are added to the weight matrix,
     * the remaining rows to the label weight matrix.
     *
     * @param batchsize the batchsize the summed updates are divided by
     * @param momentum the momentum; values {@code <= 0} disable the blending with the last update
     * @param updates summed updates keyed by RBM index; values are replaced by the applied updates
     */
    private void updateRbmCl(long batchsize, double momentum, Map<Integer, Matrix> updates) {
        int lastRbmIndex = rbmCl.getDbm().getRbmCount() - 1;

        // entrySet + setValue replaces values without touching the map structure mid-iteration
        for (Map.Entry<Integer, Matrix> entry : updates.entrySet()) {
            int rbmNr = entry.getKey();

            Matrix update = entry.getValue().divide(batchsize);
            // guard per RBM: the first update of an RBM has nothing to blend with
            if (momentum > 0 && lastUpdate[rbmNr] != null) {
                update = update.times(1 - momentum).plus(lastUpdate[rbmNr].times(momentum));
            }
            entry.setValue(update);

            if (rbmNr < lastRbmIndex) {
                SimpleRBM simpleRBM = (SimpleRBM) rbmCl.getDbm().getRBM(rbmNr);
                simpleRBM.setWeightMatrix(simpleRBM.getWeightMatrix().plus(update));
            } else {
                LabeledSimpleRBM lrbm = (LabeledSimpleRBM) rbmCl.getDbm().getRBM(rbmNr);
                int rowSize = lrbm.getWeightMatrix().rowSize();
                int columnSize = update.columnSize();

                // upper rows belong to the ordinary weights, the rest to the label weights
                Matrix weightUpdates = update.viewPart(0, rowSize, 0, columnSize);
                Matrix weightLabelUpdates = update.viewPart(rowSize, update.rowSize() - rowSize, 0, columnSize);

                lrbm.setWeightMatrix(lrbm.getWeightMatrix().plus(weightUpdates));
                lrbm.setWeightLabelMatrix(lrbm.getWeightLabelMatrix().plus(weightLabelUpdates));
            }
            lastUpdate[rbmNr] = update;
        }
    }

}