org.apache.horn.core.AbstractNeuralNetwork.java Source code

Java tutorial

Introduction

Here is the source code for `org.apache.horn.core.AbstractNeuralNetwork.java`.

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.horn.core;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.net.URI;
import java.net.URISyntaxException;

import org.apache.commons.lang.SerializationUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableUtils;
import org.apache.hama.HamaConfiguration;
import org.apache.hama.bsp.BSPJob;

import com.google.common.base.Preconditions;
import com.google.common.io.Closeables;

/**
 * NeuralNetwork defines the general operations for all the derivative models.
 * Typically, all derivative models such as Linear Regression, Logistic
 * Regression, and Multilayer Perceptron consist of neurons and the weights
 * between neurons.
 *
 * <p>This base class keeps the state shared by all models (learning rate,
 * model type, model path and feature transformer), implements its
 * {@link Writable} serialization, and can load or store a model at the
 * location given by the model path.
 */
public abstract class AbstractNeuralNetwork implements Writable {

    // Hama configuration passed to the loading constructor; null when the
    // no-argument constructor is used.
    protected HamaConfiguration conf;
    // NOTE(review): not assigned anywhere in this class; presumably set by subclasses.
    protected FileSystem fs;

    private static final float DEFAULT_LEARNING_RATE = 0.5f;

    protected float learningRate;
    // Not written by write()/readFields(), so it is not persisted with the model.
    protected boolean learningRateDecay = false;

    // the name of the model
    protected String modelType;
    // the path to store the model
    protected String modelPath;

    protected FloatFeatureTransformer featureTransformer;

    /**
     * Creates a model with the default learning rate, a model type equal to
     * the simple name of the concrete class, and a default
     * {@link FloatFeatureTransformer}.
     */
    public AbstractNeuralNetwork() {
        this.learningRate = DEFAULT_LEARNING_RATE;
        this.modelType = this.getClass().getSimpleName();
        this.featureTransformer = new FloatFeatureTransformer();
    }

    /**
     * Creates a model and loads its state from the model stored at
     * {@code modelPath}.
     *
     * <p>Defaults from the no-argument constructor are applied first, so the
     * instance is still fully initialised if loading fails. An
     * {@link IOException} raised while loading is printed and otherwise
     * ignored, because this constructor does not declare it.
     *
     * <p>NOTE(review): this calls the overridable {@link #readFields(DataInput)}
     * before subclass field initialisers run; subclass initialisers may
     * therefore overwrite state restored here.
     *
     * @param conf the Hama configuration, stored in {@link #conf}.
     * @param modelPath location of the stored model.
     * @throws IllegalArgumentException if {@code modelPath} is null.
     */
    public AbstractNeuralNetwork(HamaConfiguration conf, String modelPath) {
        this();
        this.conf = conf;
        this.modelPath = modelPath;
        try {
            this.readFromModel();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Sets the degree of aggression during model training. A large learning
     * rate can speed up training, but it also decreases the chance that the
     * model converges. A value in the range (0, 0.3) is recommended.
     *
     * @param learningRate the new learning rate; must be larger than 0.
     * @throws IllegalArgumentException if {@code learningRate} is not larger
     *     than 0.
     */
    public void setLearningRate(float learningRate) {
        Preconditions.checkArgument(learningRate > 0, "Learning rate must be larger than 0.");
        this.learningRate = learningRate;
    }

    /**
     * Returns the current learning rate.
     *
     * @return the learning rate.
     */
    public float getLearningRate() {
        return this.learningRate;
    }

    /**
     * Sets {@link #learningRateDecay}. Despite its name, this method is a
     * setter, not a query.
     *
     * @param decay the new value of the learning-rate-decay flag.
     */
    public void isLearningRateDecay(boolean decay) {
        this.learningRateDecay = decay;
    }

    /**
     * Returns the model type. It defaults to the simple class name and is
     * replaced by the value read in {@link #readFields(DataInput)}.
     *
     * @return the model type.
     */
    public String getModelType() {
        return this.modelType;
    }

    /**
     * Trains the model by delegating to
     * {@link #trainInternal(HamaConfiguration)}.
     *
     * @param conf the configuration handed to {@code trainInternal}.
     * @return the BSP job returned by {@code trainInternal}.
     * @throws IllegalArgumentException if the model path has not been set.
     * @throws InterruptedException
     * @throws ClassNotFoundException
     * @throws IOException
     */
    public BSPJob train(HamaConfiguration conf) throws ClassNotFoundException, IOException, InterruptedException {
        Preconditions.checkArgument(this.modelPath != null, "Please set the model path before training.");
        // train with BSP job
        return trainInternal(conf);
    }

    /**
     * Model-specific training implementation invoked by
     * {@link #train(HamaConfiguration)} once the model path has been checked.
     *
     * @param conf the configuration passed to {@code train}.
     * @return the BSP job used for training.
     */
    protected abstract BSPJob trainInternal(HamaConfiguration conf)
            throws IOException, InterruptedException, ClassNotFoundException;

    /**
     * Reads the model from {@link #modelPath} and restores it through
     * {@link #readFields(DataInput)}.
     *
     * @throws IllegalArgumentException if the model path has not been set.
     * @throws IOException if the path is not a valid URI, or reading fails.
     */
    protected void readFromModel() throws IOException {
        Preconditions.checkArgument(this.modelPath != null, "Model path has not been set.");
        Configuration hadoopConf = new Configuration();
        FSDataInputStream in = null;
        // Guava idiom: only propagate close() failures when the body succeeded,
        // so a close error never masks the original exception.
        boolean threw = true;
        try {
            FileSystem fileSystem = FileSystem.get(new URI(this.modelPath), hadoopConf);
            in = fileSystem.open(new Path(this.modelPath));
            this.readFields(in);
            threw = false;
        } catch (URISyntaxException e) {
            throw new IOException("Invalid model path: " + this.modelPath, e);
        } finally {
            Closeables.close(in, threw);
        }
    }

    /**
     * Writes the model via {@link #write(DataOutput)} to {@link #modelPath},
     * overwriting any existing file there.
     *
     * @throws IllegalArgumentException if the model path has not been set.
     * @throws IOException if the path is not a valid URI, or writing fails.
     */
    public void writeModelToFile() throws IOException {
        Preconditions.checkArgument(this.modelPath != null, "Model path has not been set.");
        Configuration hadoopConf = new Configuration();
        FSDataOutputStream out = null;
        // Close in finally so the stream is released even if write() fails.
        boolean threw = true;
        try {
            FileSystem fileSystem = FileSystem.get(new URI(this.modelPath), hadoopConf);
            out = fileSystem.create(new Path(this.modelPath), true);
            this.write(out);
            threw = false;
        } catch (URISyntaxException e) {
            throw new IOException("Invalid model path: " + this.modelPath, e);
        } finally {
            Closeables.close(out, threw);
        }
    }

    /**
     * Set the model path.
     *
     * @param modelPath the location used by {@link #readFromModel()} and
     *     {@link #writeModelToFile()}.
     */
    public void setModelPath(String modelPath) {
        this.modelPath = modelPath;
    }

    /**
     * Get the model path.
     *
     * @return the path to store the model.
     */
    public String getModelPath() {
        return this.modelPath;
    }

    /**
     * Restores the fields written by {@link #write(DataOutput)}: model type,
     * learning rate, model path, and a fresh instance of the feature
     * transformer class.
     *
     * @throws IOException if the input is truncated or malformed, or the
     *     feature transformer cannot be instantiated.
     */
    @Override
    public void readFields(DataInput input) throws IOException {
        // read model type
        this.modelType = WritableUtils.readString(input);
        // read learning rate
        this.learningRate = input.readFloat();
        // read model path; write() encodes a null path as the literal "null"
        this.modelPath = WritableUtils.readString(input);
        if ("null".equals(this.modelPath)) {
            this.modelPath = null;
        }

        // read feature transformer: a length-prefixed, Java-serialized Class object
        int bytesLen = input.readInt();
        if (bytesLen < 0) {
            throw new IOException("Invalid feature transformer length: " + bytesLen);
        }
        byte[] featureTransformerBytes = new byte[bytesLen];
        input.readFully(featureTransformerBytes);

        // SECURITY NOTE(review): native Java deserialization of bytes from the
        // model source; only load models from trusted locations.
        Object deserialized = SerializationUtils.deserialize(featureTransformerBytes);
        if (!(deserialized instanceof Class)
                || !FloatFeatureTransformer.class.isAssignableFrom((Class<?>) deserialized)) {
            throw new IOException("Expected a FloatFeatureTransformer class but found: " + deserialized);
        }
        Class<? extends FloatFeatureTransformer> featureTransformerCls =
                ((Class<?>) deserialized).asSubclass(FloatFeatureTransformer.class);

        this.featureTransformer = instantiateFeatureTransformer(featureTransformerCls);
    }

    /**
     * Creates a feature transformer through the no-argument constructor of
     * {@code cls}. The constructor is looked up explicitly because the order
     * of {@link Class#getDeclaredConstructors()} is unspecified.
     *
     * @param cls the transformer class to instantiate.
     * @return the new transformer.
     * @throws IOException if {@code cls} has no accessible no-argument
     *     constructor, or the constructor fails.
     */
    private static FloatFeatureTransformer instantiateFeatureTransformer(
            Class<? extends FloatFeatureTransformer> cls) throws IOException {
        Exception failure;
        // Separate catch clauses instead of multi-catch: the file uses no Java 7 syntax.
        try {
            Constructor<? extends FloatFeatureTransformer> constructor = cls.getDeclaredConstructor();
            return constructor.newInstance();
        } catch (NoSuchMethodException e) {
            failure = e;
        } catch (InstantiationException e) {
            failure = e;
        } catch (IllegalAccessException e) {
            failure = e;
        } catch (InvocationTargetException e) {
            failure = e;
        }
        throw new IOException("Unable to instantiate feature transformer " + cls.getName(), failure);
    }

    /**
     * Serialises the model type, learning rate, model path (a null path is
     * written as the literal {@code "null"}) and the class of the feature
     * transformer. The transformer's own state is not written.
     *
     * @throws NullPointerException if no feature transformer is set.
     * @throws IOException if writing fails.
     */
    @Override
    public void write(DataOutput output) throws IOException {
        // Fail before emitting any bytes rather than leaving partial output.
        Preconditions.checkNotNull(this.featureTransformer, "Feature transformer has not been set.");
        // write model type
        WritableUtils.writeString(output, modelType);
        // write learning rate
        output.writeFloat(learningRate);
        // write model path
        WritableUtils.writeString(output, this.modelPath != null ? this.modelPath : "null");

        // serialize the class
        Class<? extends FloatFeatureTransformer> featureTransformerCls = this.featureTransformer.getClass();
        byte[] featureTransformerBytes = SerializationUtils.serialize(featureTransformerCls);
        output.writeInt(featureTransformerBytes.length);
        output.write(featureTransformerBytes);
    }

    /**
     * Sets the feature transformer whose class is persisted by
     * {@link #write(DataOutput)}.
     *
     * @param featureTransformer the transformer to use.
     */
    public void setFeatureTransformer(FloatFeatureTransformer featureTransformer) {
        this.featureTransformer = featureTransformer;
    }

    /**
     * Returns the current feature transformer.
     *
     * @return the feature transformer.
     */
    public FloatFeatureTransformer getFeatureTransformer() {
        return this.featureTransformer;
    }

}