Java tutorial
/* * To change this license header, choose License Headers in Project Properties. * To change this template file, choose Tools | Templates * and open the template in the editor. */ package org.micromanager.plugins.magellan.propsandcovariants; import java.awt.geom.AffineTransform; import java.awt.geom.Point2D; import java.io.File; import java.io.FileNotFoundException; import java.rmi.activation.ActivationSystem; import java.util.Arrays; import java.util.Random; import java.util.Scanner; import org.apache.commons.math.linear.Array2DRowRealMatrix; import org.apache.commons.math.linear.MatrixUtils; import org.apache.commons.math3.geometry.euclidean.threed.Vector3D; import org.micromanager.plugins.magellan.acq.AcquisitionEvent; import org.micromanager.plugins.magellan.bidc.JavaLayerImageConstructor; import org.micromanager.plugins.magellan.coordinates.AffineUtils; import org.micromanager.plugins.magellan.coordinates.XYStagePosition; import org.micromanager.plugins.magellan.main.Magellan; import org.micromanager.plugins.magellan.misc.Log; import org.micromanager.plugins.magellan.surfacesandregions.SurfaceInterpolator; public class LaserPredNet { private static final double SEARCH_START_DIST = 400.0; private static final double SEARCH_TOLERANCE = 2.0; private static final int N_THETA_ANGLES = 12; private static final int N_HIST_BINS = 12; private static final double PHI = 0.3491; private static final int FOV_LASER_MODULATION_RESOLUTION = 16; private static final int N_HIDDENS = 200; private static final int N_INPUTS = 15; Array2DRowRealMatrix W1_, B1_, W2_, B2_; private double[] distanceMeans_, distanceSDs_; private double[][] testValues_; private double[] testValuesOutput_; private double brightness_; private double[] binedges_; public LaserPredNet(String filename, double brightness) throws FileNotFoundException { readModel(filename); //init log bins double binmax = 350; binedges_ = new double[N_HIST_BINS + 1]; for (int b = 0; b < N_HIST_BINS + 1; b++) { double linearBin = (1.0 / (double) N_HIST_BINS) * b; binedges_[b] = Math.pow(linearBin, 1.5) * binmax; } brightness_ = brightness; } public byte[] getExcitations(AcquisitionEvent e, SurfaceInterpolator surf) throws InterruptedException { XYStagePosition xyPos = e.xyPosition_; double zPos = e.zPosition_; Point2D.Double[] corners = xyPos.getFullTileCorners(); double tileSize = Math.abs(corners[2].x - corners[0].x); int pixelDim = JavaLayerImageConstructor.getInstance().getImageHeight(); AffineTransform posTransform = AffineUtils.getAffineTransform(getCurrentPixelSizeConfig(), xyPos.getCenter().x, xyPos.getCenter().y); double[][] designMat = new double[FOV_LASER_MODULATION_RESOLUTION * FOV_LASER_MODULATION_RESOLUTION][N_HIST_BINS + 3]; // designMatrix = [designMatrix tilePosition brightness]; for (int r = 0; r < designMat.length; r++) { //calculate position for this point in FOV int xPosPix = (int) (((r % FOV_LASER_MODULATION_RESOLUTION) / (double) (FOV_LASER_MODULATION_RESOLUTION - 1) - 0.5) * pixelDim); int yPosPix = (int) (((r / FOV_LASER_MODULATION_RESOLUTION) / (double) (FOV_LASER_MODULATION_RESOLUTION - 1) - 0.5) * pixelDim); Point2D.Double stageCoordPos = new Point2D.Double(); posTransform.transform(new Point2D.Double(xPosPix, yPosPix), stageCoordPos); //calculate histogram double[] hist = new double[N_HIST_BINS]; for (int thetaIndex = 0; thetaIndex < N_THETA_ANGLES; thetaIndex++) { double dist = getSampledDistancesToSurface(thetaIndex, stageCoordPos.x, stageCoordPos.y, zPos, surf); //add count to hist for (int binIndex = 0; binIndex < N_HIST_BINS; binIndex++) { if (dist < binedges_[binIndex + 1]) { hist[binIndex]++; break; } } } //standardize histogram for (int i = 0; i < hist.length; i++) { hist[i] = (hist[i] - distanceMeans_[i]) / distanceSDs_[i]; } for (int c = 0; c < designMat[0].length; c++) { if (c < N_HIST_BINS) { //add in normalized distance histogram designMat[r][c] = hist[c]; } else if (c < N_HIST_BINS + 1) { // x position double xPos = (r % FOV_LASER_MODULATION_RESOLUTION) / (double) (FOV_LASER_MODULATION_RESOLUTION - 1); designMat[r][c] = xPos - 0.5; } else if (c < N_HIST_BINS + 2) { // y position double yPos = (r / FOV_LASER_MODULATION_RESOLUTION) / (double) (FOV_LASER_MODULATION_RESOLUTION - 1); designMat[r][c] = yPos - 0.5; } else { designMat[r][c] = brightness_; } } } //use NN to predict return forwardPass(designMat); } public double getBrightness() { return brightness_; } public byte[] forwardPass(double[][] x) { double[] ones = new double[x.length]; Arrays.fill(ones, 1.0); Array2DRowRealMatrix onesMat = new Array2DRowRealMatrix(ones); //assume x is properly normalized new Array2DRowRealMatrix(x[0]); Array2DRowRealMatrix xMat = (Array2DRowRealMatrix) MatrixUtils.createRealMatrix(x); Array2DRowRealMatrix h = xMat.multiply(W1_).add(onesMat.multiply(B1_)); relu(h); Array2DRowRealMatrix z = (Array2DRowRealMatrix) h.multiply(W2_.transpose()).add(onesMat.multiply(B2_)); byte[] powers = new byte[z.getRowDimension() * z.getColumnDimension()]; for (int i = 0; i < powers.length; i++) { powers[i] = (byte) Math.max(0, Math.min(255, z.getEntry(i, 0))); } return powers; } private static void relu(Array2DRowRealMatrix activations) { for (int r = 0; r < activations.getRowDimension(); r++) { for (int c = 0; c < activations.getColumnDimension(); c++) { if (activations.getEntry(r, c) < 0) { activations.setEntry(r, c, 0.0); } } } } private void readModel(String filename) throws FileNotFoundException { Scanner s = new Scanner(new File(filename)); double[][] w1 = new double[N_INPUTS][N_HIDDENS]; double[][] b1 = new double[1][N_HIDDENS]; double[][] w2 = new double[1][N_HIDDENS]; double[][] b2 = new double[1][1]; double[][] var = null; int index = 0; int matCount = 0; while (s.hasNext()) { String line = s.nextLine(); if (line.toLowerCase().startsWith("fc") || line.toLowerCase().startsWith("output")) { //new variable if (matCount == 0) { var = w1; } else if (matCount == 1) { var = b1; } else if (matCount == 2) { var = w2; } else { var = b2; } matCount++; index = 0; } else if (line.toLowerCase().startsWith("distance")) { break; } else { String[] entries = line.split(","); for (int i = 0; i < entries.length; i++) { try { var[index / var[0].length][index % var[0].length] = Double.parseDouble(entries[i]); } catch (Exception e) { int t = 6; } index++; } } } String meanStr = s.nextLine(); // means String[] entries = meanStr.split(","); distanceMeans_ = new double[N_HIST_BINS]; for (int i = 0; i < entries.length; i++) { distanceMeans_[i] = Double.parseDouble(entries[i]); } s.nextLine(); // burn SD title String sdStr = s.nextLine(); entries = sdStr.split(","); distanceSDs_ = new double[N_HIST_BINS]; for (int i = 0; i < entries.length; i++) { distanceSDs_[i] = Double.parseDouble(entries[i]); } s.nextLine(); // burn test values title int numTestVals = 4; testValuesOutput_ = new double[numTestVals]; testValues_ = new double[numTestVals][N_HIST_BINS + 3]; for (int i = 0; i < numTestVals; i++) { String valsString = s.nextLine(); entries = valsString.split(","); for (int k = 0; k < entries.length; k++) { testValues_[i][k] = Double.parseDouble(entries[k]); } testValuesOutput_[i] = Double.parseDouble(s.nextLine()); } //convert model to Apache commons matrices W1_ = (Array2DRowRealMatrix) MatrixUtils.createRealMatrix(w1); B1_ = (Array2DRowRealMatrix) MatrixUtils.createRealMatrix(b1); W2_ = (Array2DRowRealMatrix) MatrixUtils.createRealMatrix(w2); B2_ = (Array2DRowRealMatrix) MatrixUtils.createRealMatrix(b2); //Run tests byte[] output = forwardPass(testValues_); for (int k = 0; k < output.length; k++) { System.out.println("Calculated: " + (output[k] & 0xff) + "\tGround truth:" + testValuesOutput_[k]); } } /** * * @return return distance to surface interpolation based on x y and z points */ private static double getSampledDistancesToSurface(int angleIndex, double x, double y, double z, SurfaceInterpolator surface) throws InterruptedException { double dTheta = Math.PI * 2.0 / (double) N_THETA_ANGLES; Vector3D initialPoint = new Vector3D(x, y, z); double[] distances = new double[N_THETA_ANGLES]; double theta = angleIndex * dTheta; //calculate unit vector in theta phi direction Vector3D directionUnitVec = new Vector3D(Math.cos(theta) * Math.sin(PHI), Math.sin(theta) * Math.sin(PHI), Math.cos(PHI)).scalarMultiply(-1); //binary search double initialDist = SEARCH_START_DIST; //start with a point outside and then binary line search for the distance while (isWithinSurace(surface, initialPoint.add(directionUnitVec.scalarMultiply(initialDist)))) { initialDist *= 2; } return binarySearch(initialPoint, directionUnitVec, 0, initialDist, surface); } private static boolean isWithinSurace(SurfaceInterpolator surface, Vector3D point) throws InterruptedException { boolean defined = surface.waitForCurentInterpolation().isInterpDefined(point.getX(), point.getY()); if (!defined) { return false; } float interpVal = surface.waitForCurentInterpolation().getInterpolatedValue(point.getX(), point.getY()); return point.getZ() > interpVal; } private static double binarySearch(Vector3D initialPoint, Vector3D direction, double minDistance, double maxDistance, SurfaceInterpolator surf) throws InterruptedException { double halfDistance = (minDistance + maxDistance) / 2; //if the distance has been narrowed to a sufficiently small interval, return if (maxDistance - minDistance < SEARCH_TOLERANCE) { return halfDistance; } //check if point is above surface in Vector3D searchPoint = initialPoint.add(direction.scalarMultiply(halfDistance)); boolean withinSurface = isWithinSurace(surf, searchPoint); if (!withinSurface) { return binarySearch(initialPoint, direction, minDistance, halfDistance, surf); } else { return binarySearch(initialPoint, direction, halfDistance, maxDistance, surf); } } private static String getCurrentPixelSizeConfig() { try { return Magellan.getCore().getCurrentPixelSizeConfig(); } catch (Exception ex) { Log.log("couldnt get pixel size config"); throw new RuntimeException(); } } }