Java tutorial
/* * Copyright 2011-2016 joptimizer.com * * This work is licensed under the Creative Commons Attribution-NoDerivatives 4.0 * International License. To view a copy of this license, visit * * http://creativecommons.org/licenses/by-nd/4.0/ * * or send a letter to Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. */ package com.joptimizer.solvers; import junit.framework.TestCase; import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import cern.colt.matrix.DoubleFactory1D; import cern.colt.matrix.DoubleFactory2D; import cern.colt.matrix.DoubleMatrix1D; import cern.colt.matrix.DoubleMatrix2D; import cern.colt.matrix.linalg.Algebra; import cern.colt.matrix.linalg.Property; import cern.jet.math.Functions; /** * @author alberto trivellato (alberto.trivellato@gmail.com) */ public class BasicKKTSolverTest extends TestCase { private Algebra ALG = Algebra.DEFAULT; private DoubleFactory1D F1 = DoubleFactory1D.dense; private DoubleFactory2D F2 = DoubleFactory2D.dense; private Property P = Property.TWELVE; private Log log = LogFactory.getLog(this.getClass().getName()); public void testSolveSimple() throws Exception { log.debug("testSolveSimple"); double[][] HMatrix = new double[][] { { 3 } }; double[][] AMatrix = new double[][] { { 2 } }; DoubleMatrix2D H = F2.make(HMatrix); DoubleMatrix2D A = F2.make(AMatrix); DoubleMatrix2D AT = ALG.transpose(A.copy()); DoubleMatrix1D g = F1.make(1, -3); DoubleMatrix1D h = F1.make(1, 0); KKTSolver solver = new BasicKKTSolver(); solver.setHMatrix(H); solver.setAMatrix(A); solver.setGVector(g); solver.setHVector(h); DoubleMatrix1D[] sol = solver.solve(); DoubleMatrix1D v = sol[0]; DoubleMatrix1D w = sol[1]; log.debug("v: " + ArrayUtils.toString(v.toArray())); log.debug("w: " + ArrayUtils.toString(w.toArray())); DoubleMatrix1D a = ALG.mult(H, v).assign(ALG.mult(AT, w), Functions.plus).assign(g, Functions.plus); DoubleMatrix1D b = ALG.mult(A, v).assign(h, Functions.plus); log.debug("a: " + ArrayUtils.toString(a.toArray())); log.debug("b: " + ArrayUtils.toString(b.toArray())); for (int i = 0; i < a.size(); i++) { assertEquals(0, a.get(i), 1.E-14); } for (int i = 0; i < b.size(); i++) { assertEquals(0, b.get(i), 1.E-14); } } public void testSolve2() throws Exception { log.debug("testSolve2"); double[][] HMatrix = new double[][] { { 1.68, 0.34, 0.38 }, { 0.34, 3.09, -1.59 }, { 0.38, -1.59, 1.54 } }; double[][] AMatrix = new double[][] { { 1, 2, 3 } }; DoubleMatrix2D H = F2.make(HMatrix); DoubleMatrix2D A = F2.make(AMatrix); DoubleMatrix2D AT = ALG.transpose(A.copy()); DoubleMatrix1D g = F1.make(new double[] { 2, 5, 1 }); DoubleMatrix1D h = F1.make(new double[] { 1 }); KKTSolver solver = new BasicKKTSolver(); solver.setHMatrix(H); solver.setAMatrix(A); solver.setGVector(g); solver.setHVector(h); DoubleMatrix1D[] sol = solver.solve(); DoubleMatrix1D v = sol[0]; DoubleMatrix1D w = sol[1]; log.debug("v: " + ArrayUtils.toString(v.toArray())); log.debug("w: " + ArrayUtils.toString(w.toArray())); DoubleMatrix1D a = ALG.mult(H, v).assign(ALG.mult(AT, w), Functions.plus).assign(g, Functions.plus); DoubleMatrix1D b = ALG.mult(A, v).assign(h, Functions.plus); log.debug("a: " + ArrayUtils.toString(a.toArray())); log.debug("b: " + ArrayUtils.toString(b.toArray())); for (int i = 0; i < a.size(); i++) { assertEquals(0, a.get(i), 1.E-14); } for (int i = 0; i < b.size(); i++) { assertEquals(0, b.get(i), 1.E-14); } } }