com.cloudera.knittingboar.sgd.TestParallelOnlineLogisticRegression.java Source code

Introduction

Here is the source code for com.cloudera.knittingboar.sgd.TestParallelOnlineLogisticRegression.java, a set of JUnit tests that exercise the ParallelOnlineLogisticRegression class: constructing a model with the fluent configuration calls, running training iterations, inspecting the internal beta and gamma buffers, and flushing the locally accumulated gradient.

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.cloudera.knittingboar.sgd;

import java.util.ArrayList;

import junit.framework.TestCase;

import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;

import com.cloudera.knittingboar.utils.Utils;

/**
 * Mostly temporary tests used to debug components as we developed the system
 * 
 * @author jpatterson
 *
 */
public class TestParallelOnlineLogisticRegression extends TestCase {

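    /**
     * Builds a ParallelOnlineLogisticRegression via the fluent configuration
     * calls and checks that the lambda setting is retained.
     */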
    public void testCreateLR() {

        int categories = 2;
        int numFeatures = 5;
        double lambda = 1.0e-4;
        double learning_rate = 50;

        ParallelOnlineLogisticRegression plr = new ParallelOnlineLogisticRegression(categories, numFeatures,
                new L1()).lambda(lambda).learningRate(learning_rate).alpha(1 - 1.0e-3);

        assertEquals(lambda, plr.getLambda(), 0.0);
    }

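    /**
     * Exercises the basic train() call path: a single sparse input vector
     * (feature x set to value x) is fed through train() several times to
     * confirm the mechanics run without error.
     */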
    public void testTrainMechanics() {

        int categories = 2;
        int numFeatures = 5;
        double lambda = 1.0e-4;
        double learning_rate = 10;

        ParallelOnlineLogisticRegression plr = new ParallelOnlineLogisticRegression(categories, numFeatures,
                new L1()).lambda(lambda).learningRate(learning_rate).alpha(1 - 1.0e-3);

        // single sparse input: feature x gets value x
        Vector input = new RandomAccessSparseVector(numFeatures);
        for (int x = 0; x < numFeatures; x++) {
            input.set(x, x);
        }

        plr.train(0, input);
        plr.train(0, input);
        plr.train(0, input);

    }

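    /**
     * Prints the model's internal buffers before and after one train() call:
     * beta (the coefficient matrix) and gamma (which, as used in these tests,
     * appears to accumulate the locally computed gradient between flushes).
     */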
    public void testPOLRInternalBuffers() {

        System.out.println("testPOLRInternalBuffers --------------");

        int categories = 2;
        int numFeatures = 5;
        double lambda = 1.0e-4;
        double learning_rate = 10;

        // build a training set holding a single sparse vector (feature x gets value x)
        ArrayList<Vector> trainingSet_0 = new ArrayList<Vector>();
        for (int s = 0; s < 1; s++) {
            Vector input = new RandomAccessSparseVector(numFeatures);
            for (int x = 0; x < numFeatures; x++) {
                input.set(x, x);
            }
            trainingSet_0.add(input);
        } // for

        ParallelOnlineLogisticRegression plr_agent_0 = new ParallelOnlineLogisticRegression(categories, numFeatures,
                new L1()).lambda(lambda).learningRate(learning_rate).alpha(1 - 1.0e-3);

        System.out.println("Beta: ");
        //Utils.PrintVectorNonZero(plr_agent_0.getBeta().getRow(0));
        Utils.PrintVectorNonZero(plr_agent_0.getBeta().viewRow(0));

        System.out.println("\nGamma: ");
        //Utils.PrintVectorNonZero(plr_agent_0.gamma.getMatrix().getRow(0));
        Utils.PrintVectorNonZero(plr_agent_0.gamma.getMatrix().viewRow(0));

        plr_agent_0.train(0, trainingSet_0.get(0));

        System.out.println("Beta: ");
        //Utils.PrintVectorNonZero(plr_agent_0.noReallyGetBeta().getRow(0));
        Utils.PrintVectorNonZero(plr_agent_0.noReallyGetBeta().viewRow(0));

        System.out.println("\nGamma: ");
        //Utils.PrintVectorNonZero(plr_agent_0.gamma.getMatrix().getRow(0));
        Utils.PrintVectorNonZero(plr_agent_0.gamma.getMatrix().viewRow(0));

    }

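    /**
     * Trains on one example, then calls FlushGamma() and asserts that the
     * gamma buffer (the locally accumulated gradient) has been reset to zero.
     */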
    public void testLocalGradientFlush() {

        System.out.println("\n\n\ntestLocalGradientFlush --------------");

        int categories = 2;
        int numFeatures = 5;
        double lambda = 1.0e-4;
        double learning_rate = 10;

        // build a training set holding a single sparse vector (feature x gets value x)
        ArrayList<Vector> trainingSet_0 = new ArrayList<Vector>();
        for (int s = 0; s < 1; s++) {
            Vector input = new RandomAccessSparseVector(numFeatures);
            for (int x = 0; x < numFeatures; x++) {
                input.set(x, x);
            }
            trainingSet_0.add(input);
        } // for

        ParallelOnlineLogisticRegression plr_agent_0 = new ParallelOnlineLogisticRegression(categories, numFeatures,
                new L1()).lambda(lambda).learningRate(learning_rate).alpha(1 - 1.0e-3);

        plr_agent_0.train(0, trainingSet_0.get(0));

        System.out.println("\nGamma: ");
        Utils.PrintVectorNonZero(plr_agent_0.gamma.getMatrix().viewRow(0));

        plr_agent_0.FlushGamma();

        System.out.println("Flushing Gamma ...... ");

        System.out.println("\nGamma: ");
        Utils.PrintVector(plr_agent_0.gamma.getMatrix().viewRow(0));

        for (int x = 0; x < numFeatures; x++) {

            assertEquals(0.0, plr_agent_0.gamma.getMatrix().get(0, x), 0.0);

        }

    }

}