hu.sztaki.incremental.ml.streaming.imsr.IMSR.java Source code

Java tutorial

Introduction

Here is the source code for hu.sztaki.incremental.ml.streaming.imsr.IMSR.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package hu.sztaki.incremental.ml.streaming.imsr;

import org.apache.commons.math.linear.SingularValueDecompositionImpl;
import org.apache.commons.math.linear.Array2DRowRealMatrix;
import org.apache.commons.math.linear.RealMatrix;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.function.sink.SinkFunction;

public class IMSR {

    // *************************************************************************
    // PROGRAM
    // *************************************************************************

    /**
     * Entry point: wires up the streaming job and runs it.
     *
     * <p>Optional arguments: {@code args[0]} replaces the default input file name and
     * {@code args[1]} replaces the default batch size (parsed with
     * {@link Integer#parseInt(String)}).
     *
     * @param args optional command-line arguments as described above
     * @throws Exception if argument parsing, job construction or job execution fails
     */
    public static void main(String[] args) throws Exception {

        // the execution environment is obtained first, before any argument handling
        final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

        // command-line overrides, falling back to the built-in defaults
        final String fileName = args.length > 0 ? args[0] : "src/test/resources/1.csv";
        final int batchSize = args.length > 1 ? Integer.parseInt(args[1]) : 5;

        // input: a stream of pairs of 2-D double arrays produced by the source
        final DataStream<Tuple2<double[][], double[][]>> stream = env
                .addSource(new MatrixVectorPairSource(fileName, batchSize), 1);

        final MatrixSink sink = new MatrixSink();

        // map every pair, combine the mapped pairs by summation, pass the result to the sink
        stream.map(new MatrixMapper())
                .reduce(new MatrixSumReducer())
                .addSink(sink);

        // run the job
        env.execute("Streaming Linear Regression (IMSR)");
    }

    // *************************************************************************
    // USER FUNCTIONS
    // *************************************************************************

    /**
     * Maps a pair {@code (X, y)} of 2-D arrays to the pair
     * {@code (X^T * X, X^T * y)}.
     */
    public static final class MatrixMapper
            implements MapFunction<Tuple2<double[][], double[][]>, Tuple2<double[][], double[][]>> {

        private static final long serialVersionUID = -5984071416255204043L;

        /**
         * Computes the two matrix products for one input pair.
         *
         * @param value pair whose {@code f0} is used as {@code X} and whose {@code f1} is used as {@code y}
         * @return a new pair holding {@code X^T * X} in {@code f0} and {@code X^T * y} in {@code f1}
         * @throws Exception declared by the {@link MapFunction} contract
         */
        @Override
        public Tuple2<double[][], double[][]> map(Tuple2<double[][], double[][]> value) throws Exception {
            final Array2DRowRealMatrix design = new Array2DRowRealMatrix(value.f0);
            final Array2DRowRealMatrix target = new Array2DRowRealMatrix(value.f1);

            // the transpose is re-wrapped so the Array2DRowRealMatrix multiply overloads apply
            final double[][] transposedData = design.transpose().getData();
            final Array2DRowRealMatrix designT = new Array2DRowRealMatrix(transposedData);

            final Array2DRowRealMatrix gram = designT.multiply(design);
            final Array2DRowRealMatrix moment = designT.multiply(target);

            return new Tuple2<double[][], double[][]>(gram.getDataRef(), moment.getDataRef());
        }

    }

    /**
     * Combines two pairs of 2-D arrays by adding them field by field.
     */
    public static final class MatrixSumReducer implements ReduceFunction<Tuple2<double[][], double[][]>> {

        private static final long serialVersionUID = 1143426179541008899L;

        /**
         * Adds the {@code f0} fields together and the {@code f1} fields together.
         *
         * @param value1 first pair
         * @param value2 second pair
         * @return a new pair with {@code f0 = value1.f0 + value2.f0} and
         *         {@code f1 = value1.f1 + value2.f1} (matrix addition)
         * @throws Exception declared by the {@link ReduceFunction} contract
         */
        @Override
        public Tuple2<double[][], double[][]> reduce(Tuple2<double[][], double[][]> value1,
                Tuple2<double[][], double[][]> value2) throws Exception {
            // wrap all four operands before performing any addition
            final Array2DRowRealMatrix leftMatrix = new Array2DRowRealMatrix(value1.f0);
            final Array2DRowRealMatrix rightMatrix = new Array2DRowRealMatrix(value2.f0);
            final Array2DRowRealMatrix leftVector = new Array2DRowRealMatrix(value1.f1);
            final Array2DRowRealMatrix rightVector = new Array2DRowRealMatrix(value2.f1);

            final double[][] matrixSum = leftMatrix.add(rightMatrix).getDataRef();
            final double[][] vectorSum = leftVector.add(rightVector).getDataRef();

            return new Tuple2<double[][], double[][]>(matrixSum, vectorSum);
        }

    }

    /**
     * Sink that, for every incoming pair {@code (M, v)}, computes {@code inverse(M) * v}
     * and prints the result to standard output.
     */
    public static final class MatrixSink implements SinkFunction<Tuple2<double[][], double[][]>> {
        private static final long serialVersionUID = -7966965600616447076L;

        /**
         * Inverts {@code value.f0} through a singular value decomposition solver,
         * multiplies the inverse by {@code value.f1} and prints the product.
         *
         * @param value pair whose {@code f0} is the matrix to invert and whose {@code f1}
         *              is the right-hand operand of the multiplication
         */
        @Override
        public void invoke(Tuple2<double[][], double[][]> value) {
            final Array2DRowRealMatrix lhs = new Array2DRowRealMatrix(value.f0);
            final Array2DRowRealMatrix rhs = new Array2DRowRealMatrix(value.f1);

            final RealMatrix inverse = new SingularValueDecompositionImpl(lhs).getSolver().getInverse();
            // copy into an Array2DRowRealMatrix so the typed multiply overload is used
            final Array2DRowRealMatrix inverseCopy = new Array2DRowRealMatrix(inverse.getData());

            final Array2DRowRealMatrix product = inverseCopy.multiply(rhs);
            printVector(product);
        }

        /**
         * Prints the entries of a single-row or single-column matrix on one line,
         * each followed by a space, and ends the line.
         *
         * @param m matrix to print; a matrix with more than one column is transposed
         *          first, and only column 0 of the resulting matrix is read
         */
        private void printVector(RealMatrix m) {
            assert (Math.min(m.getColumnDimension(), m.getRowDimension()) == 1);
            final RealMatrix column = m.getColumnDimension() > 1 ? m.transpose() : m;
            final int rows = column.getRowDimension();
            for (int row = 0; row < rows; row++) {
                System.out.print(column.getEntry(row, 0));
                System.out.print(" ");
            }
            System.out.println();
        }
    }

}