Java tutorial
/** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.mahout.classifier.rbm.training; import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; import org.apache.commons.cli2.builder.DefaultOptionBuilder; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileStatus; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.mapreduce.Job; import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; import org.apache.hadoop.util.ToolRunner; import org.apache.mahout.classifier.rbm.RBMClassifier; import org.apache.mahout.classifier.rbm.model.LabeledSimpleRBM; import org.apache.mahout.classifier.rbm.model.RBMModel; import org.apache.mahout.classifier.rbm.model.SimpleRBM; import org.apache.mahout.classifier.rbm.network.DeepBoltzmannMachine; import org.apache.mahout.common.AbstractJob; import org.apache.mahout.common.HadoopUtil; import org.apache.mahout.common.Pair; import org.apache.mahout.common.commandline.DefaultOptionCreator; import org.apache.mahout.common.iterator.sequencefile.PathFilters; import org.apache.mahout.common.iterator.sequencefile.PathType; import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable; import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable; import org.apache.mahout.math.DenseMatrix; import org.apache.mahout.math.DenseVector; import org.apache.mahout.math.Matrix; import org.apache.mahout.math.MatrixWritable; import org.apache.mahout.math.Vector; import org.apache.mahout.math.VectorWritable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * The Class RBMClassifierTrainingJob. */ public class RBMClassifierTrainingJob extends AbstractJob { /** The Constant WEIGHT_UPDATES. */ public static final String WEIGHT_UPDATES = "weightupdates"; /** The Constant logger. */ private static final Logger logger = LoggerFactory.getLogger(RBMClassifierTrainingJob.class); /** The last update which is needed for use of the momentum. */ Matrix[] lastUpdate; /** The rbm classifier. */ RBMClassifier rbmCl = null; /** The number of iterations (epochs). */ int epochs; /** The learningrate. */ double learningrate; /** The momentum used. */ double momentum; /** monitor if true. */ boolean monitor; /** initial biases if true. */ boolean initbiases; /** train greedy if true. */ boolean greedy; /** finetune if true. */ boolean finetuning; /** The batches to train on. */ Path[] batches = null; /** The labelcount. */ int labelcount; /** The nr gibbs sampling. */ int nrGibbsSampling; /** The rbm nr to train. */ int rbmNrtoTrain; /** * The main method. * * @param args the arguments * @throws Exception the exception */ public static void main(String[] args) throws Exception { ToolRunner.run(new Configuration(), new RBMClassifierTrainingJob(), args); } /* (non-Javadoc) * @see org.apache.hadoop.util.Tool#run(java.lang.String[]) */ @Override public int run(String[] args) throws Exception { addInputOption(); addOutputOption(); addOption("epochs", "e", "number of training epochs through the trainingset", true); addOption("structure", "s", "comma-separated list of layer sizes", false); addOption("labelcount", "lc", "total count of labels existent in the training set", true); addOption("learningrate", "lr", "learning rate at the beginning of training", "0.005"); addOption("momentum", "m", "momentum of learning at the beginning", "0.5"); addOption("rbmnr", "nr", "rbm to train, < 0 means train all", "-1"); addOption("nrgibbs", "gn", "number of gibbs sampling used in contrastive divergence", "5"); addOption(new DefaultOptionBuilder().withLongName(DefaultOptionCreator.MAPREDUCE_METHOD).withRequired(false) .withDescription("Run training with map/reduce").withShortName("mr").create()); addOption(new DefaultOptionBuilder().withLongName("nogreedy").withRequired(false) .withDescription("Don't run greedy pre training").withShortName("ng").create()); addOption(new DefaultOptionBuilder().withLongName("nofinetuning").withRequired(false) .withDescription("Don't run fine tuning at the end").withShortName("nf").create()); addOption(new DefaultOptionBuilder().withLongName("nobiases").withRequired(false) .withDescription("Don't initialize biases").withShortName("nb").create()); addOption(new DefaultOptionBuilder().withLongName("monitor").withRequired(false) .withDescription("If present, errors can be monitored in cosole").withShortName("mon").create()); addOption(DefaultOptionCreator.overwriteOption().create()); Map<String, String> parsedArgs = parseArguments(args); if (parsedArgs == null) { return -1; } Path input = getInputPath(); Path output = getOutputPath(); FileSystem fs = FileSystem.get(output.toUri(), getConf()); labelcount = Integer.parseInt(getOption("labelcount")); boolean local = !hasOption("mapreduce"); monitor = hasOption("monitor"); initbiases = !hasOption("nobiases"); finetuning = !hasOption("nofinetuning"); greedy = !hasOption("nogreedy"); if (fs.isFile(input)) batches = new Path[] { input }; else { FileStatus[] stati = fs.listStatus(input); batches = new Path[stati.length]; for (int i = 0; i < stati.length; i++) { batches[i] = stati[i].getPath(); } } epochs = Integer.valueOf(getOption("epochs")); learningrate = Double.parseDouble(getOption("learningrate")); momentum = Double.parseDouble(getOption("momentum")); rbmNrtoTrain = Integer.parseInt(getOption("rbmnr")); nrGibbsSampling = Integer.parseInt(getOption("nrgibbs")); boolean initialize = hasOption(DefaultOptionCreator.OVERWRITE_OPTION) || !fs.exists(output) || fs.listStatus(output).length <= 0; if (initialize) { String structure = getOption("structure"); if (structure == null || structure.isEmpty()) return -1; String[] layers = structure.split(","); if (layers.length < 2) { return -1; } int[] actualLayerSizes = new int[layers.length]; for (int i = 0; i < layers.length; i++) { actualLayerSizes[i] = Integer.parseInt(layers[i]); } rbmCl = new RBMClassifier(labelcount, actualLayerSizes); logger.info("New model initialized!"); } else { rbmCl = RBMClassifier.materialize(output, getConf()); logger.info("Model found and materialized!"); } HadoopUtil.setSerializations(getConf()); lastUpdate = new Matrix[rbmCl.getDbm().getRbmCount()]; if (initbiases) { //init biases! Vector biases = null; int counter = 0; for (Path batch : batches) { for (Pair<IntWritable, VectorWritable> record : new SequenceFileIterable<IntWritable, VectorWritable>( batch, getConf())) { if (biases == null) biases = record.getSecond().get().clone(); else biases.plus(record.getSecond().get()); counter++; } } if (biases == null) { logger.info("No training data found!"); return -1; } rbmCl.getDbm().getLayer(0).setBiases(biases.divide(counter)); logger.info("Biases initialized"); } //greedy pre training with gradually decreasing learningrates if (greedy) { if (!local) rbmCl.serialize(output, getConf()); double tempLearningrate = learningrate; if (rbmNrtoTrain < 0) //train all rbms for (int rbmNr = 0; rbmNr < rbmCl.getDbm().getRbmCount(); rbmNr++) { tempLearningrate = learningrate; //double weights if dbm was materialized, because it was halved after greedy pretraining if (!initialize && rbmNrtoTrain > 0 && rbmNrtoTrain < rbmCl.getDbm().getRbmCount() - 1) { ((SimpleRBM) rbmCl.getDbm().getRBM(rbmNr)).setWeightMatrix( ((SimpleRBM) rbmCl.getDbm().getRBM(rbmNr)).getWeightMatrix().times(2)); } for (int j = 0; j < epochs; j++) { logger.info("Greedy training, epoch " + (j + 1) + "\nCurrent learningrate: " + tempLearningrate); for (int b = 0; b < batches.length; b++) { tempLearningrate -= learningrate / (epochs * batches.length + epochs); if (local) { if (!trainGreedySeq(rbmNr, batches[b], j, tempLearningrate)) return -1; } else if (!trainGreedyMR(rbmNr, batches[b], j, tempLearningrate)) return -1; if (monitor && (batches.length > 19) && (b + 1) % (batches.length / 20) == 0) logger.info(rbmNr + "-RBM: " + Math.round(((double) b + 1) / batches.length * 100.0) + "% in epoch done!"); } logger.info(Math.round(((double) j + 1) / epochs * 100) + "% of training on rbm number " + rbmNr + " is done!"); if (monitor) { double error = rbmError(batches[0], rbmNr); logger.info( "Average reconstruction error on batch " + batches[0].getName() + ": " + error); } rbmCl.serialize(output, getConf()); } //weight normalization to avoid double counting if (rbmNr > 0 && rbmNr < rbmCl.getDbm().getRbmCount() - 1) { ((SimpleRBM) rbmCl.getDbm().getRBM(rbmNrtoTrain)).setWeightMatrix( ((SimpleRBM) rbmCl.getDbm().getRBM(rbmNrtoTrain)).getWeightMatrix().times(0.5)); } } else { //double weights if dbm was materialized, because it was halved after greedy pretraining if (!initialize && rbmNrtoTrain > 0 && rbmNrtoTrain < rbmCl.getDbm().getRbmCount() - 1) { ((SimpleRBM) rbmCl.getDbm().getRBM(rbmNrtoTrain)).setWeightMatrix( ((SimpleRBM) rbmCl.getDbm().getRBM(rbmNrtoTrain)).getWeightMatrix().times(2)); } //train just wanted rbm for (int j = 0; j < epochs; j++) { logger.info( "Greedy training, epoch " + (j + 1) + "\nCurrent learningrate: " + tempLearningrate); for (int b = 0; b < batches.length; b++) { tempLearningrate -= learningrate / (epochs * batches.length + epochs); if (local) { if (!trainGreedySeq(rbmNrtoTrain, batches[b], j, tempLearningrate)) return -1; } else if (!trainGreedyMR(rbmNrtoTrain, batches[b], j, tempLearningrate)) return -1; if (monitor && (batches.length > 19) && (b + 1) % (batches.length / 20) == 0) logger.info(rbmNrtoTrain + "-RBM: " + Math.round(((double) b + 1) / batches.length * 100.0) + "% in epoch done!"); } logger.info(Math.round(((double) j + 1) / epochs * 100) + "% of training is done!"); if (monitor) { double error = rbmError(batches[0], rbmNrtoTrain); logger.info("Average reconstruction error on batch " + batches[0].getName() + ": " + error); } } //weight normalization to avoid double counting if (rbmNrtoTrain > 0 && rbmNrtoTrain < rbmCl.getDbm().getRbmCount() - 1) { ((SimpleRBM) rbmCl.getDbm().getRBM(rbmNrtoTrain)).setWeightMatrix( ((SimpleRBM) rbmCl.getDbm().getRBM(rbmNrtoTrain)).getWeightMatrix().times(0.5)); } } rbmCl.serialize(output, getConf()); logger.info("Pretraining done and model written to output"); } if (finetuning) { DeepBoltzmannMachine multiLayerDbm = null; double tempLearningrate = learningrate; //finetuning job for (int j = 0; j < epochs; j++) { for (int b = 0; b < batches.length; b++) { multiLayerDbm = rbmCl.initializeMultiLayerNN(); logger.info("Finetuning on batch " + batches[b].getName() + "\nCurrent learningrate: " + tempLearningrate); tempLearningrate -= learningrate / (epochs * batches.length + epochs); if (local) { if (!finetuneSeq(batches[b], j, multiLayerDbm, tempLearningrate)) return -1; } else if (!fintuneMR(batches[b], j, tempLearningrate)) return -1; logger.info("Finetuning: " + Math.round(((double) b + 1) / batches.length * 100.0) + "% in epoch done!"); } logger.info(Math.round(((double) j + 1) / epochs * 100) + "% of training is done!"); if (monitor) { double error = feedForwardError(multiLayerDbm, batches[0]); logger.info("Average discriminative error on batch " + batches[0].getName() + ": " + error); } } //final serialization rbmCl.serialize(output, getConf()); logger.info("RBM finetuning done and model written to output"); } if (executor != null) executor.shutdownNow(); return 0; } /** * The Class BackpropTrainingThread is the callable thread for the local backprop task. */ class BackpropTrainingThread implements Callable<Matrix[]> { /** The dbm. */ private DeepBoltzmannMachine dbm; /** The input. */ private Vector input; /** The label. */ private Vector label; /** The trainer. */ private BackPropTrainer trainer; /** * Instantiates a new backprop training thread. * * @param dbm the dbm * @param label the label * @param input the input * @param trainer the trainer */ public BackpropTrainingThread(DeepBoltzmannMachine dbm, Vector label, Vector input, BackPropTrainer trainer) { this.dbm = dbm; this.label = label; this.input = input; this.trainer = trainer; } /* (non-Javadoc) * @see java.util.concurrent.Callable#call() */ @Override public Matrix[] call() throws Exception { Matrix[] result = trainer.calculateWeightUpdates(dbm, input, label); Matrix[] weightUpdates = new Matrix[dbm.getRbmCount() - 1]; //write for each RBM i (key, number of rbm) the result and put together the last two //matrices since they refer to just one labeled rbm, which was split to two for the training for (int i = 0; i < result.length - 1; i++) { if (i == result.length - 2) { weightUpdates[i] = new DenseMatrix(result[i].rowSize() + result[i + 1].columnSize(), result[i].columnSize()); for (int j = 0; j < weightUpdates[i].rowSize(); j++) for (int k = 0; k < weightUpdates[i].columnSize(); k++) { if (j < result[i].rowSize()) weightUpdates[i].set(j, k, result[i].get(j, k)); else weightUpdates[i].set(j, k, result[i + 1].get(k, j - result[i].rowSize())); } } else weightUpdates[i] = result[i]; } return weightUpdates; } } /** The backprop training tasks. */ List<BackpropTrainingThread> backpropTrainingTasks; /** * Finetune locally. * * @param batch the batch * @param iteration the iteration * @param multiLayerDbm the multilayer dbm * @param learningrate the learningrate * @return true, if successful * @throws InterruptedException the interrupted exception * @throws ExecutionException the execution exception */ private boolean finetuneSeq(Path batch, int iteration, DeepBoltzmannMachine multiLayerDbm, double learningrate) throws InterruptedException, ExecutionException { Vector label = new DenseVector(labelcount); Map<Integer, Matrix> updates = new HashMap<Integer, Matrix>(); int batchsize = 0; //maximum number of threads that are used, I think 20 is ok int threadCount = 20; Matrix[] weightUpdates; //initialize the tasks, which are run by the executor if (backpropTrainingTasks == null) backpropTrainingTasks = new ArrayList<BackpropTrainingThread>(); //initialize the executor if not already done if (executor == null) executor = Executors.newFixedThreadPool(threadCount); for (Pair<IntWritable, VectorWritable> record : new SequenceFileIterable<IntWritable, VectorWritable>(batch, getConf())) { for (int i = 0; i < label.size(); i++) label.setQuick(i, 0); label.set(record.getFirst().get(), 1); BackPropTrainer trainer = new BackPropTrainer(learningrate); //prepare the tasks if (backpropTrainingTasks.size() < threadCount) backpropTrainingTasks.add(new BackpropTrainingThread(multiLayerDbm.clone(), label.clone(), record.getSecond().get(), trainer)); else { backpropTrainingTasks.get(batchsize % threadCount).input = record.getSecond().get(); backpropTrainingTasks.get(batchsize % threadCount).label = label.clone(); if (batchsize < threadCount) { backpropTrainingTasks.get(batchsize % threadCount).dbm = multiLayerDbm.clone(); } } //run the tasks and save results if (batchsize % threadCount == threadCount - 1) { List<Future<Matrix[]>> futureUpdates = executor.invokeAll(backpropTrainingTasks); for (int i = 0; i < futureUpdates.size(); i++) { weightUpdates = futureUpdates.get(i).get(); for (int j = 0; j < weightUpdates.length; j++) { if (updates.containsKey(j)) updates.put(j, weightUpdates[j].plus(updates.get(j))); else updates.put(j, weightUpdates[j]); } } } batchsize++; } //run remaining tasks if (batchsize % 20 != 0) { List<Future<Matrix[]>> futureUpdates = executor .invokeAll(backpropTrainingTasks.subList(0, (batchsize - 1) % 20)); for (int i = 0; i < futureUpdates.size(); i++) { weightUpdates = futureUpdates.get(i).get(); for (int j = 0; j < weightUpdates.length; j++) { if (updates.containsKey(j)) updates.put(j, weightUpdates[j].plus(updates.get(j))); else updates.put(j, weightUpdates[j]); } } } updateRbmCl(batchsize, (iteration == 0) ? 0 : momentum, updates); return true; } /** * Fintune using map/reduce. * * @param batch the batch * @param iteration the iteration * @param learningrate the learningrate * @return true, if successful * @throws IOException Signals that an I/O exception has occurred. * @throws InterruptedException the interrupted exception * @throws ClassNotFoundException the class not found exception */ private boolean fintuneMR(Path batch, int iteration, double learningrate) throws IOException, InterruptedException, ClassNotFoundException { //prepare and run finetune job long batchsize; HadoopUtil.delete(getConf(), getTempPath(WEIGHT_UPDATES)); HadoopUtil.cacheFiles(getOutputPath(), getConf()); Job trainDBM = prepareJob(batch, getTempPath(WEIGHT_UPDATES), SequenceFileInputFormat.class, DBMBackPropTrainingMapper.class, IntWritable.class, MatrixWritable.class, DBMBackPropTrainingReducer.class, IntWritable.class, MatrixWritable.class, SequenceFileOutputFormat.class); trainDBM.getConfiguration().set("labelcount", String.valueOf(labelcount)); trainDBM.getConfiguration().set("learningrate", String.valueOf(learningrate)); trainDBM.setCombinerClass(DBMBackPropTrainingReducer.class); if (!trainDBM.waitForCompletion(true)) return false; batchsize = trainDBM.getCounters().findCounter(DBMBackPropTrainingMapper.BATCHES.SIZE).getValue(); changeAndSaveModel(getOutputPath(), batchsize, (iteration == 0) ? 0 : momentum); return true; } /** * The Class GreedyTrainingThread. */ class GreedyTrainingThread implements Callable<Matrix> { /** The dbm. */ private DeepBoltzmannMachine dbm; /** The input. */ private Vector input; /** The label. */ private Vector label; /** The trainer. */ private CDTrainer trainer; /** The rbm nr to train. */ int rbmNr; /** * Instantiates a new greedy training thread. * * @param dbm the dbm * @param label the label * @param input the input * @param trainer the trainer * @param rbmNr the rbm nr */ public GreedyTrainingThread(DeepBoltzmannMachine dbm, Vector label, Vector input, CDTrainer trainer, int rbmNr) { this.dbm = dbm; this.label = label; this.input = input; this.trainer = trainer; this.rbmNr = rbmNr; } /* (non-Javadoc) * @see java.util.concurrent.Callable#call() */ @Override public Matrix call() throws Exception { Matrix updates = null; dbm.getRBM(0).getVisibleLayer().setActivations(input); for (int i = 0; i < rbmNr; i++) { //double the bottom up connection for initialization dbm.getRBM(i).exciteHiddenLayer(2, false); if (i == rbmNr - 1) //probabilities as activation for the data the rbm should train on dbm.getRBM(i).getHiddenLayer().setProbabilitiesAsActivation(); else dbm.getRBM(i).getHiddenLayer().updateNeurons(); } if (rbmNr == dbm.getRbmCount() - 1) { ((LabeledSimpleRBM) dbm.getRBM(rbmNr)).getSoftmaxLayer().setActivations(label); updates = trainer.calculateWeightUpdates((LabeledSimpleRBM) dbm.getRBM(rbmNr), true, false); } else { updates = trainer.calculateWeightUpdates((SimpleRBM) dbm.getRBM(rbmNr), false, rbmNr == 0); } return updates; } } /** The executor. */ private ExecutorService executor; /** The greedy training tasks. */ List<GreedyTrainingThread> greedyTrainingTasks; /** * Train greedy seq. * * @param rbmNr the rbm nr * @param batch the batch * @param iteration the iteration * @param learningrate the learningrate * @return true, if successful * @throws InterruptedException the interrupted exception * @throws ExecutionException the execution exception */ private boolean trainGreedySeq(int rbmNr, Path batch, int iteration, double learningrate) throws InterruptedException, ExecutionException { int batchsize = 0; DeepBoltzmannMachine dbm = rbmCl.getDbm(); Vector label = new DenseVector(labelcount); Matrix updates = null; //number of threads running the tasks int threadCount = 20; if (executor == null) executor = Executors.newFixedThreadPool(threadCount); if (greedyTrainingTasks == null) greedyTrainingTasks = new ArrayList<RBMClassifierTrainingJob.GreedyTrainingThread>(); for (Pair<IntWritable, VectorWritable> record : new SequenceFileIterable<IntWritable, VectorWritable>(batch, getConf())) { CDTrainer trainer = new CDTrainer(learningrate, nrGibbsSampling); label.assign(0); label.set(record.getFirst().get(), 1); //prepare the tasks if (greedyTrainingTasks.size() < threadCount) greedyTrainingTasks.add(new GreedyTrainingThread(dbm.clone(), label.clone(), record.getSecond().get(), trainer, rbmNr)); else { greedyTrainingTasks.get(batchsize % threadCount).input = record.getSecond().get(); greedyTrainingTasks.get(batchsize % threadCount).label = label.clone(); if (batchsize < threadCount) { greedyTrainingTasks.get(batchsize % threadCount).dbm = dbm.clone(); greedyTrainingTasks.get(batchsize % threadCount).rbmNr = rbmNr; } } //run tasks if (batchsize % threadCount == threadCount - 1) { List<Future<Matrix>> futureUpdates = executor.invokeAll(greedyTrainingTasks); for (int i = 0; i < futureUpdates.size(); i++) { if (updates == null) updates = futureUpdates.get(i).get(); else updates = updates.plus(futureUpdates.get(i).get()); } } batchsize++; } //run remaining tasks if (batchsize % 20 != 0) { List<Future<Matrix>> futureUpdates = executor .invokeAll(greedyTrainingTasks.subList(0, (batchsize - 1) % 20)); for (int i = 0; i < futureUpdates.size(); i++) { if (updates == null) updates = futureUpdates.get(i).get(); else updates = updates.plus(futureUpdates.get(i).get()); } } Map<Integer, Matrix> updateMap = new HashMap<Integer, Matrix>(); updateMap.put(rbmNr, updates); updateRbmCl(batchsize, (lastUpdate[rbmNr] == null) ? 0 : momentum, updateMap); return true; } /** * Train greedy mr. * * @param rbmNr the rbm nr * @param batch the batch * @param iteration the iteration * @param learningrate the learningrate * @return true, if successful * @throws IOException Signals that an I/O exception has occurred. * @throws InterruptedException the interrupted exception * @throws ClassNotFoundException the class not found exception */ private boolean trainGreedyMR(int rbmNr, Path batch, int iteration, double learningrate) throws IOException, InterruptedException, ClassNotFoundException { //run greedy pretraining as map reduce job long batchsize; HadoopUtil.delete(getConf(), getTempPath(WEIGHT_UPDATES)); HadoopUtil.cacheFiles(getOutputPath(), getConf()); Job trainRBM = prepareJob(batch, getTempPath(WEIGHT_UPDATES), SequenceFileInputFormat.class, RBMGreedyPreTrainingMapper.class, IntWritable.class, MatrixWritable.class, RBMGreedyPreTrainingReducer.class, IntWritable.class, MatrixWritable.class, SequenceFileOutputFormat.class); trainRBM.getConfiguration().set("rbmNr", String.valueOf(rbmNr)); trainRBM.getConfiguration().set("labelcount", String.valueOf(labelcount)); trainRBM.getConfiguration().set("learningrate", String.valueOf(learningrate)); trainRBM.getConfiguration().set("nrGibbsSampling", String.valueOf(nrGibbsSampling)); trainRBM.setCombinerClass(RBMGreedyPreTrainingReducer.class); if (!trainRBM.waitForCompletion(true)) return false; batchsize = trainRBM.getCounters().findCounter(RBMGreedyPreTrainingMapper.BATCH.SIZE).getValue(); changeAndSaveModel(getOutputPath(), batchsize, (lastUpdate[rbmNr] == null) ? 0 : momentum); return true; } /** * calculate classifiers error after 1 iteration of sampling. * * @param batch the batch * @return the error */ @SuppressWarnings("unused") private double classifierError(Path batch) { double error = 0; int counter = 0; Vector scores; for (Pair<IntWritable, VectorWritable> record : new SequenceFileIterable<IntWritable, VectorWritable>(batch, getConf())) { scores = rbmCl.classify(record.getSecond().get(), 1); error += 1 - scores.get(record.getFirst().get()); counter++; } error /= counter; return error; } /** * Calculates error of fann. * * @param feedForwardNet the feed forward net * @param batch the batch * @return the error */ private double feedForwardError(DeepBoltzmannMachine feedForwardNet, Path batch) { double error = 0; int counter = 0; RBMModel currentRBM = null; for (Pair<IntWritable, VectorWritable> record : new SequenceFileIterable<IntWritable, VectorWritable>(batch, getConf())) { feedForwardNet.getRBM(0).getVisibleLayer().setActivations(record.getSecond().get()); for (int i = 0; i < feedForwardNet.getRbmCount(); i++) { currentRBM = feedForwardNet.getRBM(i); currentRBM.exciteHiddenLayer(1, false); currentRBM.getHiddenLayer().setProbabilitiesAsActivation(); } error += 1 - currentRBM.getHiddenLayer().getActivations().get(record.getFirst().get()); counter++; } error /= counter; return error; } /** * Rbms reconstruction error. * * @param batch the batch * @param rbmNr the rbm nr * @return the error */ private double rbmError(Path batch, int rbmNr) { DeepBoltzmannMachine dbm = rbmCl.getDbm(); Vector label = new DenseVector( ((LabeledSimpleRBM) dbm.getRBM(dbm.getRbmCount() - 1)).getSoftmaxLayer().getNeuronCount()); double error = 0; int counter = 0; for (Pair<IntWritable, VectorWritable> record : new SequenceFileIterable<IntWritable, VectorWritable>(batch, getConf())) { dbm.getRBM(0).getVisibleLayer().setActivations(record.getSecond().get()); for (int i = 0; i < rbmNr; i++) { //double the bottom up connection for initialization dbm.getRBM(i).exciteHiddenLayer(2, false); if (i == rbmNr - 1) dbm.getRBM(i).getHiddenLayer().setProbabilitiesAsActivation(); else dbm.getRBM(i).getHiddenLayer().updateNeurons(); } if (dbm.getRBM(rbmNr) instanceof LabeledSimpleRBM) { label.assign(0); label.set(record.getFirst().get(), 1); ((LabeledSimpleRBM) dbm.getRBM(rbmNr)).getSoftmaxLayer().setActivations(label); } error += dbm.getRBM(rbmNr).getReconstructionError(); counter++; } error /= counter; return error; } /** * Change and save model. * * @param output the output * @param batchsize the batchsize * @param momentum the momentum * @throws IOException Signals that an I/O exception has occurred. */ private void changeAndSaveModel(Path output, long batchsize, double momentum) throws IOException { Map<Integer, Matrix> updates = new HashMap<Integer, Matrix>(); for (Pair<IntWritable, MatrixWritable> record : new SequenceFileDirIterable<IntWritable, MatrixWritable>( getTempPath(WEIGHT_UPDATES), PathType.LIST, PathFilters.partFilter(), getConf())) { if (!updates.containsKey(record.getFirst().get())) updates.put(record.getFirst().get(), record.getSecond().get()); else updates.put(record.getFirst().get(), record.getSecond().get().plus(updates.get(record.getFirst().get()))); } updateRbmCl(batchsize, momentum, updates); //serialization for mappers to have actual version of the dbm rbmCl.serialize(output, getConf()); } /** * Update rbm classifier with given updates. * * @param batchsize the batchsize * @param momentum the momentum * @param updates the updates */ private void updateRbmCl(long batchsize, double momentum, Map<Integer, Matrix> updates) { for (Integer rbmNr : updates.keySet()) { if (momentum > 0) updates.put(rbmNr, (updates.get(rbmNr).divide(batchsize).times(1 - momentum)) .plus(lastUpdate[rbmNr].times(momentum))); else updates.put(rbmNr, updates.get(rbmNr).divide(batchsize)); if (rbmNr < rbmCl.getDbm().getRbmCount() - 1) { SimpleRBM simpleRBM = (SimpleRBM) rbmCl.getDbm().getRBM(rbmNr); simpleRBM.setWeightMatrix(simpleRBM.getWeightMatrix().plus(updates.get(rbmNr))); } else { LabeledSimpleRBM lrbm = (LabeledSimpleRBM) rbmCl.getDbm().getRBM(rbmNr); int rowSize = lrbm.getWeightMatrix().rowSize(); Matrix weightUpdates = updates.get(rbmNr).viewPart(0, rowSize, 0, updates.get(rbmNr).columnSize()); Matrix weightLabelUpdates = updates.get(rbmNr).viewPart(rowSize, updates.get(rbmNr).rowSize() - rowSize, 0, updates.get(rbmNr).columnSize()); lrbm.setWeightMatrix(lrbm.getWeightMatrix().plus(weightUpdates)); lrbm.setWeightLabelMatrix(lrbm.getWeightLabelMatrix().plus(weightLabelUpdates)); } lastUpdate[rbmNr] = updates.get(rbmNr); } } }