com.caseystella.analytics.outlier.batch.rpca.RPCA.java Source code

Java tutorial

Introduction

Here is the source code for com.caseystella.analytics.outlier.batch.rpca.RPCA.java

Source

/**
 * Copyright (C) 2016 Hurence (support@hurence.com)
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *         http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.caseystella.analytics.outlier.batch.rpca;

import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.SingularValueDecomposition;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;

public class RPCA {

    private RealMatrix X;
    private RealMatrix L;
    private RealMatrix S;
    private RealMatrix E;

    private double lpenalty;
    private double spenalty;

    private static final int MAX_ITERS = 228;

    public RPCA(double[][] data, double lpenalty, double spenalty) {
        this.X = MatrixUtils.createRealMatrix(data);
        this.lpenalty = lpenalty;
        this.spenalty = spenalty;
        initMatrices();
        computeRSVD();
    }

    public RPCA(RealMatrix X, double lpenalty, double spenalty) {
        this.X = X;
        this.lpenalty = lpenalty;
        this.spenalty = spenalty;
        initMatrices();
        computeRSVD();
    }

    private void initMatrices() {
        this.L = MatrixUtils.createRealMatrix(this.X.getRowDimension(), this.X.getColumnDimension());
        this.S = MatrixUtils.createRealMatrix(this.X.getRowDimension(), this.X.getColumnDimension());
        this.E = MatrixUtils.createRealMatrix(this.X.getRowDimension(), this.X.getColumnDimension());
    }

    private void computeRSVD() {
        double mu = X.getColumnDimension() * X.getRowDimension() / (4 * l1norm(X.getData()));
        double objPrev = 0.5 * Math.pow(X.getFrobeniusNorm(), 2);
        double obj = objPrev;
        double tol = 1e-8 * objPrev;
        double diff = 2 * tol;
        int iter = 0;

        while (diff > tol && iter < MAX_ITERS) {
            double nuclearNorm = computeS(mu);
            double l1Norm = computeL(mu);
            double l2Norm = computeE();

            obj = computeObjective(nuclearNorm, l1Norm, l2Norm);
            diff = Math.abs(objPrev - obj);
            objPrev = obj;

            mu = computeDynamicMu();

            iter = iter + 1;
        }
    }

    private double[] softThreshold(double[] x, double penalty) {
        for (int i = 0; i < x.length; i++) {
            x[i] = Math.signum(x[i]) * Math.max(Math.abs(x[i]) - penalty, 0);
        }
        return x;
    }

    private double[][] softThreshold(double[][] x, double penalty) {
        for (int i = 0; i < x.length; i++) {
            for (int j = 0; j < x[i].length; j++) {
                x[i][j] = Math.signum(x[i][j]) * Math.max(Math.abs(x[i][j]) - penalty, 0);
            }
        }
        return x;
    }

    private double sum(double[] x) {
        double sum = 0;
        for (int i = 0; i < x.length; i++)
            sum += x[i];
        return (sum);
    }

    private double l1norm(double[][] x) {
        double l1norm = 0;
        for (int i = 0; i < x.length; i++) {
            for (int j = 0; j < x[i].length; j++) {
                l1norm += Math.abs(x[i][j]);
            }
        }
        return l1norm;
    }

    private double computeL(double mu) {
        double LPenalty = lpenalty * mu;
        SingularValueDecomposition svd = new SingularValueDecomposition(X.subtract(S));
        double[] penalizedD = softThreshold(svd.getSingularValues(), LPenalty);
        RealMatrix D_matrix = MatrixUtils.createRealDiagonalMatrix(penalizedD);
        L = svd.getU().multiply(D_matrix).multiply(svd.getVT());
        return sum(penalizedD) * LPenalty;
    }

    private double computeS(double mu) {
        double SPenalty = spenalty * mu;
        double[][] penalizedS = softThreshold(X.subtract(L).getData(), SPenalty);
        S = MatrixUtils.createRealMatrix(penalizedS);
        return l1norm(penalizedS) * SPenalty;
    }

    private double computeE() {
        E = X.subtract(L).subtract(S);
        double norm = E.getFrobeniusNorm();
        return Math.pow(norm, 2);
    }

    private double computeObjective(double nuclearnorm, double l1norm, double l2norm) {
        return 0.5 * l2norm + nuclearnorm + l1norm;
    }

    private double computeDynamicMu() {
        int m = E.getRowDimension();
        int n = E.getColumnDimension();

        double E_sd = standardDeviation(E.getData());
        double mu = E_sd * Math.sqrt(2 * Math.max(m, n));

        return Math.max(.01, mu);
    }

    /*private double MedianAbsoluteDeviation(double[][] x) {
       DescriptiveStatistics stats = new DescriptiveStatistics();
       for (int i = 0; i < x.length; i ++)
     for (int j = 0; j < x[i].length; j++)
        stats.addValue(x[i][j]);
       double median = stats.getPercentile(50);
        
       DescriptiveStatistics absoluteDeviationStats = new DescriptiveStatistics();
       for (int i = 0; i < x.length; i ++)
     for (int j = 0; j < x[i].length; j++)
        absoluteDeviationStats.addValue(Math.abs(x[i][j] - median));
        
       return absoluteDeviationStats.getPercentile(50) * 1.4826;
    }*/

    private double standardDeviation(double[][] x) {
        DescriptiveStatistics stats = new DescriptiveStatistics();
        for (int i = 0; i < x.length; i++)
            for (int j = 0; j < x[i].length; j++)
                stats.addValue(x[i][j]);
        return stats.getStandardDeviation();
    }

    public RealMatrix getL() {
        return L;
    }

    public RealMatrix getS() {
        return S;
    }

    public RealMatrix getE() {
        return E;
    }

}