Java examples for Big Data: Apache Spark
Logistic regression-based classification using Apache Spark
/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.spark.examples; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.Function2; import java.io.Serializable; import java.util.Arrays; import java.util.Random; import java.util.regex.Pattern; /** * Logistic regression based classification. * * This is an example implementation for learning how to use Spark. For more conventional use, * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs. 
*/ public final class JavaHdfsLR { private static final int D = 10; // Number of dimensions private static final Random rand = new Random(42); static void showWarning() { String warning = "WARN: This is a naive implementation of Logistic Regression " + "and is given as an example!\n" + "Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD " + "or org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS " + "for more conventional use."; System.err.println(warning); } static class DataPoint implements Serializable { DataPoint(double[] x, double y) { this.x = x; this.y = y; } double[] x; double y; } static class ParsePoint implements Function<String, DataPoint> { private static final Pattern SPACE = Pattern.compile(" "); @Override public DataPoint call(String line) { String[] tok = SPACE.split(line); double y = Double.parseDouble(tok[0]); double[] x = new double[D]; for (int i = 0; i < D; i++) { x[i] = Double.parseDouble(tok[i + 1]); } return new DataPoint(x, y); } } static class VectorSum implements Function2<double[], double[], double[]> { @Override public double[] call(double[] a, double[] b) { double[] result = new double[D]; for (int j = 0; j < D; j++) { result[j] = a[j] + b[j]; } return result; } } static class ComputeGradient implements Function<DataPoint, double[]> { private final double[] weights; ComputeGradient(double[] weights) { this.weights = weights; } @Override public double[] call(DataPoint p) { double[] gradient = new double[D]; for (int i = 0; i < D; i++) { double dot = dot(weights, p.x); gradient[i] = (1 / (1 + Math.exp(-p.y * dot)) - 1) * p.y * p.x[i]; } return gradient; } } public static double dot(double[] a, double[] b) { double x = 0; for (int i = 0; i < D; i++) { x += a[i] * b[i]; } return x; } public static void printWeights(double[] a) { System.out.println(Arrays.toString(a)); } public static void main(String[] args) { if (args.length < 2) { System.err.println("Usage: JavaHdfsLR <file> <iters>"); 
System.exit(1); } showWarning(); SparkConf sparkConf = new SparkConf().setAppName("JavaHdfsLR"); JavaSparkContext sc = new JavaSparkContext(sparkConf); JavaRDD<String> lines = sc.textFile(args[0]); JavaRDD<DataPoint> points = lines.map(new ParsePoint()).cache(); int ITERATIONS = Integer.parseInt(args[1]); // Initialize w to a random value double[] w = new double[D]; for (int i = 0; i < D; i++) { w[i] = 2 * rand.nextDouble() - 1; } System.out.print("Initial w: "); printWeights(w); for (int i = 1; i <= ITERATIONS; i++) { System.out.println("On iteration " + i); double[] gradient = points.map(new ComputeGradient(w)).reduce( new VectorSum()); for (int j = 0; j < D; j++) { w[j] -= gradient[j]; } } System.out.print("Final w: "); printWeights(w); sc.stop(); } }