Java tutorial
/** * Copyright (C) 2016 Hurence (support@hurence.com) * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.caseystella.analytics.outlier.batch.rpca; import org.apache.commons.math3.linear.MatrixUtils; import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.linear.SingularValueDecomposition; import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics; public class RPCA { private RealMatrix X; private RealMatrix L; private RealMatrix S; private RealMatrix E; private double lpenalty; private double spenalty; private static final int MAX_ITERS = 228; public RPCA(double[][] data, double lpenalty, double spenalty) { this.X = MatrixUtils.createRealMatrix(data); this.lpenalty = lpenalty; this.spenalty = spenalty; initMatrices(); computeRSVD(); } public RPCA(RealMatrix X, double lpenalty, double spenalty) { this.X = X; this.lpenalty = lpenalty; this.spenalty = spenalty; initMatrices(); computeRSVD(); } private void initMatrices() { this.L = MatrixUtils.createRealMatrix(this.X.getRowDimension(), this.X.getColumnDimension()); this.S = MatrixUtils.createRealMatrix(this.X.getRowDimension(), this.X.getColumnDimension()); this.E = MatrixUtils.createRealMatrix(this.X.getRowDimension(), this.X.getColumnDimension()); } private void computeRSVD() { double mu = X.getColumnDimension() * X.getRowDimension() / (4 * l1norm(X.getData())); double objPrev = 0.5 * Math.pow(X.getFrobeniusNorm(), 2); double obj = objPrev; double tol = 1e-8 * objPrev; double diff = 2 * tol; int iter = 0; while (diff > tol && iter < MAX_ITERS) { double nuclearNorm = computeS(mu); double l1Norm = computeL(mu); double l2Norm = computeE(); obj = computeObjective(nuclearNorm, l1Norm, l2Norm); diff = Math.abs(objPrev - obj); objPrev = obj; mu = computeDynamicMu(); iter = iter + 1; } } private double[] softThreshold(double[] x, double penalty) { for (int i = 0; i < x.length; i++) { x[i] = Math.signum(x[i]) * Math.max(Math.abs(x[i]) - penalty, 0); } return x; } private double[][] softThreshold(double[][] x, double penalty) { for (int i = 0; i < x.length; i++) { for (int j = 0; j < x[i].length; j++) { x[i][j] = Math.signum(x[i][j]) * Math.max(Math.abs(x[i][j]) - penalty, 0); } } return x; } private double sum(double[] x) { double sum = 0; for (int i = 0; i < x.length; i++) sum += x[i]; return (sum); } private double l1norm(double[][] x) { double l1norm = 0; for (int i = 0; i < x.length; i++) { for (int j = 0; j < x[i].length; j++) { l1norm += Math.abs(x[i][j]); } } return l1norm; } private double computeL(double mu) { double LPenalty = lpenalty * mu; SingularValueDecomposition svd = new SingularValueDecomposition(X.subtract(S)); double[] penalizedD = softThreshold(svd.getSingularValues(), LPenalty); RealMatrix D_matrix = MatrixUtils.createRealDiagonalMatrix(penalizedD); L = svd.getU().multiply(D_matrix).multiply(svd.getVT()); return sum(penalizedD) * LPenalty; } private double computeS(double mu) { double SPenalty = spenalty * mu; double[][] penalizedS = softThreshold(X.subtract(L).getData(), SPenalty); S = MatrixUtils.createRealMatrix(penalizedS); return l1norm(penalizedS) * SPenalty; } private double computeE() { E = X.subtract(L).subtract(S); double norm = E.getFrobeniusNorm(); return Math.pow(norm, 2); } private double computeObjective(double nuclearnorm, double l1norm, double l2norm) { return 0.5 * l2norm + nuclearnorm + l1norm; } private double computeDynamicMu() { int m = E.getRowDimension(); int n = E.getColumnDimension(); double E_sd = standardDeviation(E.getData()); double mu = E_sd * Math.sqrt(2 * Math.max(m, n)); return Math.max(.01, mu); } /*private double MedianAbsoluteDeviation(double[][] x) { DescriptiveStatistics stats = new DescriptiveStatistics(); for (int i = 0; i < x.length; i ++) for (int j = 0; j < x[i].length; j++) stats.addValue(x[i][j]); double median = stats.getPercentile(50); DescriptiveStatistics absoluteDeviationStats = new DescriptiveStatistics(); for (int i = 0; i < x.length; i ++) for (int j = 0; j < x[i].length; j++) absoluteDeviationStats.addValue(Math.abs(x[i][j] - median)); return absoluteDeviationStats.getPercentile(50) * 1.4826; }*/ private double standardDeviation(double[][] x) { DescriptiveStatistics stats = new DescriptiveStatistics(); for (int i = 0; i < x.length; i++) for (int j = 0; j < x[i].length; j++) stats.addValue(x[i][j]); return stats.getStandardDeviation(); } public RealMatrix getL() { return L; } public RealMatrix getS() { return S; } public RealMatrix getE() { return E; } }