Java tutorial
package trainableSegmentation; /** * This is a small Plugin that should perform better in segmentation than thresholding * The idea is to train a random forest classifier on given manual labels * and then classify the whole image * I try to keep parameters hidden from the user to make usage of the plugin * intuitive and easy. I decided that it is better to need more manual annotations * for training and do feature selection instead of having the user manually tune * all filters. * * ToDos: * - work with color features * - work on whole Stack * - delete annotations with a shortkey * - change training image * - do probability output (accessible?) and define threshold * - put thread solution to wiki http://fiji.sc/wiki/index.php/Developing_Fiji#Writing_plugins * * - give more feedback when classifier is trained or applied * * License: GPL * * This program is free software; you can redistribute it and/or * modify it under the terms of the GNU General Public License 2 * as published by the Free Software Foundation. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. 
* * Authors: Verena Kaynig (verena.kaynig@inf.ethz.ch), Ignacio Arganda-Carreras (iarganda@mit.edu) * Albert Cardona (acardona@ini.phys.ethz.ch) */ import ij.IJ; import ij.ImageStack; import ij.plugin.PlugIn; import ij.process.FloatPolygon; import ij.process.FloatProcessor; import ij.process.ImageProcessor; import ij.process.LUT; import ij.gui.ImageWindow; import ij.gui.PolygonRoi; import ij.gui.Roi; import ij.gui.ShapeRoi; import ij.io.OpenDialog; import ij.io.SaveDialog; import ij.ImagePlus; import ij.WindowManager; import java.text.DecimalFormat; import java.util.ArrayList; import java.util.Enumeration; import java.util.List; import java.util.Vector; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.awt.AlphaComposite; import java.awt.Checkbox; import java.awt.Color; import java.awt.Component; import java.awt.Composite; import java.awt.Dimension; import java.awt.Graphics; import java.awt.GridBagConstraints; import java.awt.GridBagLayout; import java.awt.Insets; import java.awt.Panel; import java.awt.Rectangle; import java.awt.TextField; import java.awt.event.ActionEvent; import java.awt.event.ActionListener; import java.awt.event.ComponentAdapter; import java.awt.event.ComponentEvent; import java.awt.event.KeyAdapter; import java.awt.event.KeyEvent; import java.awt.event.ItemEvent; import java.awt.event.ItemListener; import java.awt.event.KeyListener; import java.awt.event.WindowAdapter; import java.awt.event.WindowEvent; import java.io.BufferedReader; import java.io.BufferedWriter; import java.io.FileNotFoundException; import java.io.FileOutputStream; import java.io.FileReader; import java.io.IOException; import 
java.io.OutputStreamWriter; import java.io.File; import javax.swing.BorderFactory; import javax.swing.JButton; import javax.swing.JOptionPane; import javax.swing.JPanel; import javax.swing.JFileChooser; import javax.swing.SwingUtilities; import weka.classifiers.AbstractClassifier; import weka.core.Attribute; import weka.core.DenseInstance; import weka.core.Instances; import fiji.util.gui.GenericDialogPlus; import fiji.util.gui.OverlayedImageCanvas; import hr.irb.fastRandomForest.FastRandomForest; public class Trainable_Segmentation implements PlugIn { final Composite transparency050 = AlphaComposite.getInstance(AlphaComposite.SRC_OVER, 0.50f); final Composite transparency025 = AlphaComposite.getInstance(AlphaComposite.SRC_OVER, 0.25f); int overlayOpacity = 33; Composite overlayAlpha = AlphaComposite.getInstance(AlphaComposite.SRC_OVER, overlayOpacity / 100f); /** maximum number of classes (labels) allowed on the GUI*/ private static final int MAX_NUM_CLASSES = 5; /** array of lists of Rois for each class */ private List<Roi>[] examples = new ArrayList[MAX_NUM_CLASSES]; /** image to be used in the training */ private ImagePlus trainingImage; /** image to display on the GUI, it includes the painted rois */ private ImagePlus displayImage; /** result image after classification */ private ImagePlus classifiedImage; /** features to be used in the training */ private FeatureStack featureStack = null; /** GUI window */ private CustomWindow win; /** array of number of traces per class */ private int traceCounter[] = new int[MAX_NUM_CLASSES]; /** flag to display the overlay image */ private boolean showColorOverlay; /** set of instances for the whole training image */ private Instances wholeImageData; /** set of instances from loaded data (previously saved segmentation) */ private Instances loadedTrainingData; /** current classifier */ private AbstractClassifier classifier = null; /** default classifier (Fast Random Forest) */ private FastRandomForest rf; /** flag to update 
the whole set of instances (used when there is any change on the features) */ private boolean updateWholeData = true; /** train classifier button */ JButton trainButton; /** toggle overlay button */ JButton overlayButton; /** create result button */ JButton resultButton; /** apply classifier button */ JButton applyButton; /** create probability image button */ JButton probimgButton; /** load data button */ JButton loadDataButton; /** save data button */ JButton saveDataButton; /** settings button */ JButton settingsButton; /** create new class button */ JButton addClassButton; /** array of roi list overlays to paint the transparent rois of each class */ RoiListOverlay[] roiOverlay; /** current segmentation result overlay */ ImageOverlay resultOverlay; /** available colors for available classes*/ final Color[] colors = new Color[] { Color.red, Color.green, Color.blue, Color.cyan, Color.magenta }; /** names of the current classes */ String[] classLabels = new String[] { "class 1", "class 2", "class 3", "class 4", "class 5" }; LUT overlayLUT; /** current number of classes */ private int numOfClasses = 2; /** array of trace lists for every class */ private java.awt.List exampleList[]; /** array of buttons for adding each trace class */ private JButton[] addExampleButton; // Random Forest parameters /** current number of trees in the fast random forest classifier */ private int numOfTrees = 200; /** current number of random features per tree in the fast random forest classifier */ private int randomFeatures = 2; /** list of class names on the loaded data */ ArrayList<String> loadedClassNames = null; /** executor service to launch threads for the plugin methods and events */ final ExecutorService exec = Executors.newFixedThreadPool(1); /** GUI/no GUI flag */ private boolean useGUI = true; /** * Basic constructor */ public Trainable_Segmentation() { this.useGUI = true; // Create overlay LUT final byte[] red = new byte[256]; final byte[] green = new byte[256]; final byte[] 
blue = new byte[256]; final int shift = 255 / MAX_NUM_CLASSES; for (int i = 0; i < 256; i++) { final int colorIndex = i / (shift + 1); //IJ.log("i = " + i + " color index = " + colorIndex); red[i] = (byte) colors[colorIndex].getRed(); green[i] = (byte) colors[colorIndex].getGreen(); blue[i] = (byte) colors[colorIndex].getBlue(); } overlayLUT = new LUT(red, green, blue); exampleList = new java.awt.List[MAX_NUM_CLASSES]; addExampleButton = new JButton[MAX_NUM_CLASSES]; roiOverlay = new RoiListOverlay[MAX_NUM_CLASSES]; resultOverlay = new ImageOverlay(); trainButton = new JButton("Train classifier"); trainButton.setToolTipText("Start training the classifier"); overlayButton = new JButton("Toggle overlay"); overlayButton.setToolTipText("Toggle between current segmentation and original image"); overlayButton.setEnabled(false); resultButton = new JButton("Create result"); resultButton.setToolTipText("Generate result image"); resultButton.setEnabled(false); applyButton = new JButton("Apply classifier"); applyButton.setToolTipText("Load data and apply current classifier"); applyButton.setEnabled(false); probimgButton = new JButton("Create probability image"); probimgButton.setToolTipText( "Instead of creating a segmentation, create a multi-channel image containing the probabilities for each class"); probimgButton.setEnabled(false); loadDataButton = new JButton("Load data"); loadDataButton.setToolTipText("Load previous segmentation from an ARFF file"); saveDataButton = new JButton("Save data"); saveDataButton.setToolTipText("Save current segmentation into an ARFF file"); addClassButton = new JButton("Create new class"); addClassButton.setToolTipText("Add one more label to mark different areas"); settingsButton = new JButton("Settings"); settingsButton.setToolTipText("Display settings dialog"); for (int i = 0; i < numOfClasses; i++) { examples[i] = new ArrayList<Roi>(); exampleList[i] = new java.awt.List(5); exampleList[i].setForeground(colors[i]); } showColorOverlay = 
false; // Initialization of Fast Random Forest classifier rf = new FastRandomForest(); rf.setNumTrees(numOfTrees); //this is the default that Breiman suggests //rf.setNumFeatures((int) Math.round(Math.sqrt(featureStack.getSize()))); //but this seems to work better rf.setNumFeatures(randomFeatures); rf.setSeed(123); classifier = rf; } /** * Listeners */ private ActionListener listener = new ActionListener() { public void actionPerformed(final ActionEvent e) { // listen to the buttons on separate threads not to block // the event dispatch thread exec.submit(new Runnable() { public void run() { if (e.getSource() == trainButton) { try { trainClassifier(); } catch (Exception e) { e.printStackTrace(); } } else if (e.getSource() == overlayButton) { toggleOverlay(); } else if (e.getSource() == resultButton) { showClassificationImage(); } else if (e.getSource() == applyButton) { applyClassifierToTestData(); } else if (e.getSource() == probimgButton) { createProbImgFromTestData(); } else if (e.getSource() == loadDataButton) { loadTrainingData(); } else if (e.getSource() == saveDataButton) { saveTrainingData(); } else if (e.getSource() == addClassButton) { addNewClass(); } else if (e.getSource() == settingsButton) { showSettingsDialog(); } else { for (int i = 0; i < numOfClasses; i++) { if (e.getSource() == exampleList[i]) { deleteSelected(e); break; } if (e.getSource() == addExampleButton[i]) { addExamples(i); break; } } } } }); } }; /** * Item listener for the trace lists */ private ItemListener itemListener = new ItemListener() { public void itemStateChanged(final ItemEvent e) { exec.submit(new Runnable() { public void run() { for (int i = 0; i < numOfClasses; i++) { if (e.getSource() == exampleList[i]) listSelected(e, i); } } }); } }; /** * Custom canvas to deal with zooming an panning */ private class CustomCanvas extends OverlayedImageCanvas { CustomCanvas(ImagePlus imp) { super(imp); Dimension dim = new Dimension(Math.min(512, imp.getWidth()), Math.min(512, 
imp.getHeight())); setMinimumSize(dim); setSize(dim.width, dim.height); setDstDimensions(dim.width, dim.height); addKeyListener(new KeyAdapter() { public void keyReleased(KeyEvent ke) { repaint(); } }); } //@Override public void setDrawingSize(int w, int h) { } public void setDstDimensions(int width, int height) { super.dstWidth = width; super.dstHeight = height; // adjust srcRect: can it grow/shrink? int w = Math.min((int) (width / magnification), imp.getWidth()); int h = Math.min((int) (height / magnification), imp.getHeight()); int x = srcRect.x; if (x + w > imp.getWidth()) x = w - imp.getWidth(); int y = srcRect.y; if (y + h > imp.getHeight()) y = h - imp.getHeight(); srcRect.setRect(x, y, w, h); repaint(); } //@Override public void paint(Graphics g) { Rectangle srcRect = getSrcRect(); double mag = getMagnification(); int dw = (int) (srcRect.width * mag); int dh = (int) (srcRect.height * mag); g.setClip(0, 0, dw, dh); super.paint(g); int w = getWidth(); int h = getHeight(); g.setClip(0, 0, w, h); // Paint away the outside g.setColor(getBackground()); g.fillRect(dw, 0, w - dw, h); g.fillRect(0, dh, w, h - dh); } } /** * Custom window to define the trainable segmentation GUI */ private class CustomWindow extends ImageWindow { /** layout for annotation panel */ GridBagLayout boxAnnotation = new GridBagLayout(); /** constraints for annotation panel */ GridBagConstraints annotationsConstraints = new GridBagConstraints(); /** Panel with class radio buttons and lists */ JPanel annotationsPanel = new JPanel(); JPanel buttonsPanel = new JPanel(); JPanel trainingJPanel = new JPanel(); JPanel optionsJPanel = new JPanel(); Panel all = new Panel(); CustomWindow(ImagePlus imp) { super(imp, new CustomCanvas(imp)); final CustomCanvas canvas = (CustomCanvas) getCanvas(); // add roi list overlays (one per class) for (int i = 0; i < MAX_NUM_CLASSES; i++) { roiOverlay[i] = new RoiListOverlay(); roiOverlay[i].setComposite(transparency050); ((OverlayedImageCanvas) 
ic).addOverlay(roiOverlay[i]); } // add result overlay resultOverlay.setComposite(overlayAlpha); ((OverlayedImageCanvas) ic).addOverlay(resultOverlay); // Remove the canvas from the window, to add it later removeAll(); setTitle("Trainable Segmentation"); // Annotations panel annotationsConstraints.anchor = GridBagConstraints.NORTHWEST; annotationsConstraints.gridwidth = 1; annotationsConstraints.gridheight = 1; annotationsConstraints.gridx = 0; annotationsConstraints.gridy = 0; annotationsPanel.setBorder(BorderFactory.createTitledBorder("Labels")); annotationsPanel.setLayout(boxAnnotation); for (int i = 0; i < numOfClasses; i++) { exampleList[i].addActionListener(listener); exampleList[i].addItemListener(itemListener); addExampleButton[i] = new JButton("Add to " + classLabels[i]); addExampleButton[i].setToolTipText("Add markings of label '" + classLabels[i] + "'"); annotationsConstraints.fill = GridBagConstraints.HORIZONTAL; annotationsConstraints.insets = new Insets(5, 5, 6, 6); boxAnnotation.setConstraints(addExampleButton[i], annotationsConstraints); annotationsPanel.add(addExampleButton[i]); annotationsConstraints.gridy++; annotationsConstraints.insets = new Insets(0, 0, 0, 0); boxAnnotation.setConstraints(exampleList[i], annotationsConstraints); annotationsPanel.add(exampleList[i]); annotationsConstraints.gridy++; } // Select first class addExampleButton[0].setSelected(true); // Add listeners for (int i = 0; i < numOfClasses; i++) addExampleButton[i].addActionListener(listener); trainButton.addActionListener(listener); overlayButton.addActionListener(listener); resultButton.addActionListener(listener); applyButton.addActionListener(listener); probimgButton.addActionListener(listener); loadDataButton.addActionListener(listener); saveDataButton.addActionListener(listener); addClassButton.addActionListener(listener); settingsButton.addActionListener(listener); // Training panel (left side of the GUI) 
trainingJPanel.setBorder(BorderFactory.createTitledBorder("Training")); GridBagLayout trainingLayout = new GridBagLayout(); GridBagConstraints trainingConstraints = new GridBagConstraints(); trainingConstraints.anchor = GridBagConstraints.NORTHWEST; trainingConstraints.fill = GridBagConstraints.HORIZONTAL; trainingConstraints.gridwidth = 1; trainingConstraints.gridheight = 1; trainingConstraints.gridx = 0; trainingConstraints.gridy = 0; trainingConstraints.insets = new Insets(5, 5, 6, 6); trainingJPanel.setLayout(trainingLayout); trainingJPanel.add(trainButton, trainingConstraints); trainingConstraints.gridy++; trainingJPanel.add(overlayButton, trainingConstraints); trainingConstraints.gridy++; trainingJPanel.add(resultButton, trainingConstraints); trainingConstraints.gridy++; // Options panel optionsJPanel.setBorder(BorderFactory.createTitledBorder("Options")); GridBagLayout optionsLayout = new GridBagLayout(); GridBagConstraints optionsConstraints = new GridBagConstraints(); optionsConstraints.anchor = GridBagConstraints.NORTHWEST; optionsConstraints.fill = GridBagConstraints.HORIZONTAL; optionsConstraints.gridwidth = 1; optionsConstraints.gridheight = 1; optionsConstraints.gridx = 0; optionsConstraints.gridy = 0; optionsConstraints.insets = new Insets(5, 5, 6, 6); optionsJPanel.setLayout(optionsLayout); optionsJPanel.add(applyButton, optionsConstraints); optionsConstraints.gridy++; optionsJPanel.add(probimgButton, optionsConstraints); optionsConstraints.gridy++; optionsJPanel.add(loadDataButton, optionsConstraints); optionsConstraints.gridy++; optionsJPanel.add(saveDataButton, optionsConstraints); optionsConstraints.gridy++; optionsJPanel.add(addClassButton, optionsConstraints); optionsConstraints.gridy++; optionsJPanel.add(settingsButton, optionsConstraints); optionsConstraints.gridy++; // Buttons panel (including training and options) GridBagLayout buttonsLayout = new GridBagLayout(); GridBagConstraints buttonsConstraints = new GridBagConstraints(); 
buttonsPanel.setLayout(buttonsLayout); buttonsConstraints.anchor = GridBagConstraints.NORTHWEST; buttonsConstraints.fill = GridBagConstraints.HORIZONTAL; buttonsConstraints.gridwidth = 1; buttonsConstraints.gridheight = 1; buttonsConstraints.gridx = 0; buttonsConstraints.gridy = 0; buttonsPanel.add(trainingJPanel, buttonsConstraints); buttonsConstraints.gridy++; buttonsPanel.add(optionsJPanel, buttonsConstraints); buttonsConstraints.gridy++; buttonsConstraints.insets = new Insets(5, 5, 6, 6); GridBagLayout layout = new GridBagLayout(); GridBagConstraints allConstraints = new GridBagConstraints(); all.setLayout(layout); allConstraints.anchor = GridBagConstraints.NORTHWEST; allConstraints.fill = GridBagConstraints.BOTH; allConstraints.gridwidth = 1; allConstraints.gridheight = 1; allConstraints.gridx = 0; allConstraints.gridy = 0; allConstraints.weightx = 0; allConstraints.weighty = 0; all.add(buttonsPanel, allConstraints); allConstraints.gridx++; allConstraints.weightx = 1; allConstraints.weighty = 1; all.add(canvas, allConstraints); allConstraints.gridx++; allConstraints.anchor = GridBagConstraints.NORTHEAST; allConstraints.weightx = 0; allConstraints.weighty = 0; all.add(annotationsPanel, allConstraints); GridBagLayout wingb = new GridBagLayout(); GridBagConstraints winc = new GridBagConstraints(); winc.anchor = GridBagConstraints.NORTHWEST; winc.fill = GridBagConstraints.BOTH; winc.weightx = 1; winc.weighty = 1; setLayout(wingb); add(all, winc); // Propagate all listeners for (Component p : new Component[] { all, buttonsPanel }) { for (KeyListener kl : getKeyListeners()) { p.addKeyListener(kl); } } addWindowListener(new WindowAdapter() { public void windowClosing(WindowEvent e) { //IJ.log("closing window"); // cleanup exec.shutdownNow(); for (int i = 0; i < numOfClasses; i++) addExampleButton[i].removeActionListener(listener); trainButton.removeActionListener(listener); overlayButton.removeActionListener(listener); resultButton.removeActionListener(listener); 
applyButton.removeActionListener(listener); probimgButton.removeActionListener(listener); loadDataButton.removeActionListener(listener); saveDataButton.removeActionListener(listener); addClassButton.removeActionListener(listener); settingsButton.removeActionListener(listener); // Set number of classes back to 2 numOfClasses = 2; } }); canvas.addComponentListener(new ComponentAdapter() { public void componentResized(ComponentEvent ce) { Rectangle r = canvas.getBounds(); canvas.setDstDimensions(r.width, r.height); } }); } /* public void changeDisplayImage(ImagePlus imp){ super.getImagePlus().setProcessor(imp.getProcessor()); super.getImagePlus().setTitle(imp.getTitle()); }private void saveFeatureStack() { // TODO Auto-generated method stub } */ /** * Repaint all panels */ public void repaintAll() { this.annotationsPanel.repaint(); getCanvas().repaint(); this.buttonsPanel.repaint(); this.all.repaint(); } /** * Add new segmentation class (new label and new list on the right side) */ public void addClass() { examples[numOfClasses] = new ArrayList<Roi>(); exampleList[numOfClasses] = new java.awt.List(5); exampleList[numOfClasses].setForeground(colors[numOfClasses]); exampleList[numOfClasses].addActionListener(listener); exampleList[numOfClasses].addItemListener(itemListener); addExampleButton[numOfClasses] = new JButton("Add to " + classLabels[numOfClasses]); annotationsConstraints.fill = GridBagConstraints.HORIZONTAL; annotationsConstraints.insets = new Insets(5, 5, 6, 6); boxAnnotation.setConstraints(addExampleButton[numOfClasses], annotationsConstraints); annotationsPanel.add(addExampleButton[numOfClasses]); annotationsConstraints.gridy++; annotationsConstraints.insets = new Insets(0, 0, 0, 0); boxAnnotation.setConstraints(exampleList[numOfClasses], annotationsConstraints); annotationsPanel.add(exampleList[numOfClasses]); annotationsConstraints.gridy++; // Add listener to the new button addExampleButton[numOfClasses].addActionListener(listener); // increase number of 
available classes numOfClasses++; //IJ.log("new number of classes = " + numOfClasses); repaintAll(); } } /** * Plugin run method */ public void run(String arg) { // trainingImage = IJ.openImage("testImages/i00000-1.tif"); //get current image if (null == WindowManager.getCurrentImage()) { trainingImage = IJ.openImage(); if (null == trainingImage) return; // user canceled open dialog } else trainingImage = new ImagePlus("Trainable Segmentation", WindowManager.getCurrentImage().getProcessor().duplicate()); if (Math.max(trainingImage.getWidth(), trainingImage.getHeight()) > 1024) if (!IJ.showMessageWithCancel("Warning", "At least one dimension of the image \n" + "is larger than 1024 pixels. \n" + "Feature stack creation and classifier training \n" + "might take some time depending on your computer.\n" + "Proceed?")) return; trainingImage.setProcessor("Trainable Segmentation", trainingImage.getProcessor().duplicate().convertToByte(true)); // Initialize feature stack (no features yet) featureStack = new FeatureStack(trainingImage); displayImage = new ImagePlus(); displayImage.setProcessor("Trainable Segmentation", trainingImage.getProcessor().duplicate()); ij.gui.Toolbar.getInstance().setTool(ij.gui.Toolbar.FREELINE); //Build GUI SwingUtilities.invokeLater(new Runnable() { public void run() { win = new CustomWindow(displayImage); win.pack(); } }); //trainingImage.getWindow().setVisible(false); } /** * Enable / disable buttons * @param s enabling flag */ private void setButtonsEnabled(Boolean s) { if (useGUI) { trainButton.setEnabled(s); overlayButton.setEnabled(s); resultButton.setEnabled(s); applyButton.setEnabled(s); probimgButton.setEnabled(s); loadDataButton.setEnabled(s); saveDataButton.setEnabled(s); addClassButton.setEnabled(s); settingsButton.setEnabled(s); for (int i = 0; i < numOfClasses; i++) { exampleList[i].setEnabled(s); addExampleButton[i].setEnabled(s); } } } /** * Add examples defined by the user to the corresponding list * @param i list index */ private 
void addExamples(int i) { //get selected pixels final Roi r = displayImage.getRoi(); if (null == r) { return; } displayImage.killRoi(); examples[i].add(r); exampleList[i].add("trace " + traceCounter[i]); traceCounter[i]++; drawExamples(); } /** * Draw the painted traces on the display image */ private void drawExamples() { for (int i = 0; i < numOfClasses; i++) { roiOverlay[i].setColor(colors[i]); final ArrayList<Roi> rois = new ArrayList<Roi>(); for (Roi r : examples[i]) { rois.add(r); //IJ.log("painted ROI: " + r + " in color "+ colors[i]); } roiOverlay[i].setRoi(rois); } displayImage.updateAndDraw(); } /** * Write current instances into an ARFF file * @param data set of instances * @param filename ARFF file name */ public void writeDataToARFF(Instances data, String filename) { try { BufferedWriter out = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(filename))); try { out.write(data.toString()); out.close(); } catch (IOException e) { IJ.showMessage("IOException"); } } catch (FileNotFoundException e) { IJ.showMessage("File not found!"); } } /** * Read ARFF file * @param filename ARFF file name * @return set of instances read from the file */ public Instances readDataFromARFF(String filename) { try { BufferedReader reader = new BufferedReader(new FileReader(filename)); try { Instances data = new Instances(reader); // setting class attribute data.setClassIndex(data.numAttributes() - 1); reader.close(); return data; } catch (IOException e) { IJ.showMessage("IOException"); } } catch (FileNotFoundException e) { IJ.showMessage("File not found!"); } return null; } /** * Create training instances out of the user markings * @return set of instances */ public Instances createTrainingInstances() { //IJ.log("create training instances: num of features = " + featureStack.getSize()); ArrayList<Attribute> attributes = new ArrayList<Attribute>(); for (int i = 1; i <= featureStack.getSize(); i++) { String attString = featureStack.getSliceLabel(i); 
attributes.add(new Attribute(attString)); } final ArrayList<String> classes = new ArrayList<String>(); int numOfInstances = 0; for (int i = 0; i < numOfClasses; i++) { // Do not add empty lists if (examples[i].size() > 0) classes.add(classLabels[i]); numOfInstances += examples[i].size(); } attributes.add(new Attribute("class", classes)); final Instances trainingData = new Instances("segment", attributes, numOfInstances); IJ.log("\nTraining input:"); // For all classes for (int l = 0; l < numOfClasses; l++) { int nl = 0; // Read all lists of examples for (int j = 0; j < examples[l].size(); j++) { Roi r = examples[l].get(j); // For polygon rois we get the list of points if (r instanceof PolygonRoi && r.getType() != Roi.FREEROI) { if (r.getStrokeWidth() == 1) { int[] x = r.getPolygon().xpoints; int[] y = r.getPolygon().ypoints; final int n = r.getPolygon().npoints; for (int i = 0; i < n; i++) { double[] values = new double[featureStack.getSize() + 1]; for (int z = 1; z <= featureStack.getSize(); z++) values[z - 1] = featureStack.getProcessor(z).getPixelValue(x[i], y[i]); values[featureStack.getSize()] = (double) l; trainingData.add(new DenseInstance(1.0, values)); // increase number of instances for this class nl++; } } else // For thicker lines, include also neighbors { final int width = (int) Math.round(r.getStrokeWidth()); FloatPolygon p = r.getFloatPolygon(); int n = p.npoints; double x1, y1; double x2 = p.xpoints[0] - (p.xpoints[1] - p.xpoints[0]); double y2 = p.ypoints[0] - (p.ypoints[1] - p.ypoints[0]); for (int i = 0; i < n; i++) { x1 = x2; y1 = y2; x2 = p.xpoints[i]; y2 = p.ypoints[i]; double dx = x2 - x1; double dy = y1 - y2; double length = (float) Math.sqrt(dx * dx + dy * dy); dx /= length; dy /= length; double x = x2 - dy * width / 2.0; double y = y2 - dx * width / 2.0; int n2 = width; do { if (x >= 0 && x < featureStack.getWidth() && y >= 0 && y < featureStack.getHeight()) { double[] values = new double[featureStack.getSize() + 1]; for (int z = 1; z <= 
featureStack.getSize(); z++) values[z - 1] = featureStack.getProcessor(z).getInterpolatedValue(x, y); values[featureStack.getSize()] = (double) l; trainingData.add(new DenseInstance(1.0, values)); // increase number of instances for this class nl++; } x += dy; y += dx; } while (--n2 > 0); } } } else // for the rest of rois we get ALL points inside the roi { final ShapeRoi shapeRoi = new ShapeRoi(r); final Rectangle rect = shapeRoi.getBounds(); final int lastX = rect.x + rect.width; final int lastY = rect.y + rect.height; for (int x = rect.x; x < lastX; x++) for (int y = rect.y; y < lastY; y++) if (shapeRoi.contains(x, y)) { double[] values = new double[featureStack.getSize() + 1]; for (int z = 1; z <= featureStack.getSize(); z++) values[z - 1] = featureStack.getProcessor(z).getPixelValue(x, y); values[featureStack.getSize()] = (double) l; trainingData.add(new DenseInstance(1.0, values)); // increase number of instances for this class nl++; } } } IJ.log("# of pixels selected as " + classLabels[l] + ": " + nl); } return trainingData; } /** * Train classifier with the current instances */ public void trainClassifier() { // Two list of examples need to be non empty int nonEmpty = 0; for (int i = 0; i < numOfClasses; i++) if (examples[i].size() > 0) nonEmpty++; if (nonEmpty < 2 && loadedTrainingData == null) { IJ.showMessage("Cannot train without at least 2 sets of examples!"); return; } // Disable buttons until the training has finished setButtonsEnabled(false); // Create feature stack if it was not created yet if (featureStack.isEmpty()) { IJ.showStatus("Creating feature stack..."); featureStack.updateFeaturesMT(); } IJ.showStatus("Training classifier..."); Instances data = null; if (nonEmpty < 2) IJ.log("Training from loaded data only..."); else { final long start = System.currentTimeMillis(); data = createTrainingInstances(); final long end = System.currentTimeMillis(); IJ.log("Creating training data took: " + (end - start) + "ms"); 
data.setClassIndex(data.numAttributes() - 1); } if (loadedTrainingData != null && data != null) { IJ.log("Merging data..."); for (int i = 0; i < loadedTrainingData.numInstances(); i++) data.add(loadedTrainingData.instance(i)); IJ.log("Finished"); } else if (data == null) { data = loadedTrainingData; IJ.log("Taking loaded data as only data..."); } IJ.showStatus("Training classifier..."); IJ.log("Training classifier..."); if (null == data) { IJ.log("WTF"); } // Train the classifier on the current data final long start = System.currentTimeMillis(); try { classifier.buildClassifier(data); } catch (Exception e) { IJ.showMessage(e.getMessage()); e.printStackTrace(); return; } final long end = System.currentTimeMillis(); final DecimalFormat df = new DecimalFormat("0.0000"); final String outOfBagError = (rf != null) ? ", out of bag error: " + df.format(rf.measureOutOfBagError()) : ""; IJ.log("Finished training in " + (end - start) + "ms" + outOfBagError); if (updateWholeData) { updateTestSet(); IJ.log("Test dataset updated (" + wholeImageData.numInstances() + " instances, " + wholeImageData.numAttributes() + " attributes)."); } IJ.log("Classifying whole image..."); classifiedImage = applyClassifier(wholeImageData, trainingImage.getWidth(), trainingImage.getHeight(), Runtime.getRuntime().availableProcessors()); IJ.log("Finished segmentation of whole image."); if (useGUI) { overlayButton.setEnabled(true); resultButton.setEnabled(true); applyButton.setEnabled(true); probimgButton.setEnabled(true); showColorOverlay = false; toggleOverlay(); setButtonsEnabled(true); } //featureStack.show(); } /** * Update whole data set with current number of classes and features */ private void updateTestSet() { IJ.showStatus("Reading whole image data..."); long start = System.currentTimeMillis(); ArrayList<String> classNames = null; if (loadedTrainingData != null) classNames = loadedClassNames; else { classNames = new ArrayList<String>(); for (int i = 0; i < numOfClasses; i++) if 
(examples[i].size() > 0) classNames.add(classLabels[i]); } wholeImageData = featureStack.createInstances(classNames); long end = System.currentTimeMillis(); IJ.log("Creating whole image data took: " + (end - start) + "ms"); wholeImageData.setClassIndex(wholeImageData.numAttributes() - 1); updateWholeData = false; } /** * Apply current classifier to set of instances * @param data set of instances * @param w image width * @param h image height * @param numThreads number of threads to create * @param prob create a multi-channel probability image * @return result image */ public ImagePlus applyClassifier(final Instances data, final int w, final int h, final int numThreads) { IJ.log("Applying classifier in " + numThreads + " threads..."); IJ.showStatus("Classifying image..."); final long start = System.currentTimeMillis(); final ExecutorService exe = Executors.newFixedThreadPool(numThreads); final double[][] results = new double[numThreads][]; final Instances[] partialData = new Instances[numThreads]; final int partialSize = data.numInstances() / numThreads; Future<double[]> fu[] = new Future[numThreads]; final AtomicInteger counter = new AtomicInteger(); //IJ.log("Dividing dataset into subsets for parallel execution..."); for (int i = 0; i < numThreads; i++) { if (i == numThreads - 1) partialData[i] = new Instances(data, i * partialSize, data.numInstances() - i * partialSize); else partialData[i] = new Instances(data, i * partialSize, partialSize); fu[i] = exe.submit(classifyIntances(partialData[i], classifier, counter)); } ScheduledExecutorService monitor = Executors.newScheduledThreadPool(1); ScheduledFuture task = monitor.scheduleWithFixedDelay(new Runnable() { public void run() { IJ.showProgress(counter.get(), data.numInstances()); } }, 0, 1, TimeUnit.SECONDS); //IJ.log("Waiting for jobs..."); // Join threads for (int i = 0; i < numThreads; i++) { try { results[i] = fu[i].get(); } catch (InterruptedException e) { IJ.log("Interruption exception"); 
e.printStackTrace(); return null; } catch (ExecutionException e) { IJ.log("Execution exception"); e.printStackTrace(); return null; } finally { exe.shutdown(); task.cancel(true); monitor.shutdownNow(); IJ.showProgress(1); } } exe.shutdown(); // Create final array double[] classificationResult = new double[data.numInstances()]; for (int i = 0; i < numThreads; i++) System.arraycopy(results[i], 0, classificationResult, i * partialSize, results[i].length); IJ.showProgress(1.0); final long end = System.currentTimeMillis(); IJ.log("Classifying whole image data took: " + (end - start) + "ms"); IJ.showStatus("Displaying result..."); final ImageProcessor classifiedImageProcessor = new FloatProcessor(w, h, classificationResult); classifiedImageProcessor.convertToByte(true); ImagePlus classImg = new ImagePlus("Classification result", classifiedImageProcessor); return classImg; } /** * Apply current classifier to set of instances to get a probability * distribution. * * @param data set of instances * @param w image width * @param h image height * @param numThreads number of threads to be used * @return result image */ public ImagePlus[] getClassifierDistribution(final Instances data, int w, int h, final int numThreads) { IJ.log("Calculating probability distribution in " + numThreads + " threads..."); final long start = System.currentTimeMillis(); final ExecutorService exe = Executors.newFixedThreadPool(numThreads); final double[][][] results = new double[numThreads][][]; final Instances[] partialData = new Instances[numThreads]; final int partialSize = data.numInstances() / numThreads; Future<double[][]> fu[] = new Future[numThreads]; final AtomicInteger counter = new AtomicInteger(); //IJ.log("Dividing dataset into subsets for parallel execution..."); for (int i = 0; i < numThreads; i++) { if (i == numThreads - 1) partialData[i] = new Instances(data, i * partialSize, data.numInstances() - i * partialSize); else partialData[i] = new Instances(data, i * partialSize, 
partialSize); fu[i] = exe.submit(probFromInstances(partialData[i], classifier, counter)); } ScheduledExecutorService monitor = Executors.newScheduledThreadPool(1); ScheduledFuture task = monitor.scheduleWithFixedDelay(new Runnable() { public void run() { IJ.showProgress(counter.get(), data.numInstances()); } }, 0, 1, TimeUnit.SECONDS); //IJ.log("Waiting for jobs..."); // Join threads for (int i = 0; i < numThreads; i++) { try { results[i] = fu[i].get(); } catch (InterruptedException e) { IJ.log("Interruption exception"); e.printStackTrace(); return null; } catch (ExecutionException e) { IJ.log("Execution exception"); e.printStackTrace(); return null; } finally { exe.shutdown(); task.cancel(true); monitor.shutdownNow(); IJ.showProgress(1); } } exe.shutdown(); // Create final array double[][] probDistribution = new double[numOfClasses][data.numInstances()]; for (int c = 0; c < numOfClasses; c++) for (int i = 0; i < numThreads; i++) System.arraycopy(results[i][c], 0, probDistribution[c], i * partialSize, results[i][c].length); IJ.showProgress(1.0); final long end = System.currentTimeMillis(); IJ.log("Probability distribution for whole image data took: " + (end - start) + "ms"); IJ.showStatus("Displaying result..."); ImagePlus[] classImgs = new ImagePlus[numOfClasses]; for (int c = 0; c < numOfClasses; c++) { final ImageProcessor classifiedImageProcessor = new FloatProcessor(w, h, probDistribution[c]); classifiedImageProcessor.convertToByte(true); classImgs[c] = new ImagePlus("Classification result", classifiedImageProcessor); } return classImgs; } /** * Classify instance concurrently * @param data set of instances to classify * @param classifier current classifier * @return classification result */ private static Callable<double[]> classifyIntances(final Instances data, final AbstractClassifier classifier, final AtomicInteger counter) { return new Callable<double[]>() { public double[] call() { final int numInstances = data.numInstances(); final double[] 
classificationResult = new double[numInstances]; for (int i = 0; i < numInstances; i++) { try { if (0 == i % 4000) counter.addAndGet(4000); classificationResult[i] = classifier.classifyInstance(data.instance(i)); } catch (Exception e) { IJ.showMessage("Could not apply Classifier!"); e.printStackTrace(); return null; } } return classificationResult; } }; } /** * Get probability distribution for classified instance concurrently * @param data classified set of instances * @param classifier current classifier * @return classification result */ private static Callable<double[][]> probFromInstances(final Instances data, final AbstractClassifier classifier, final AtomicInteger counter) { return new Callable<double[][]>() { public double[][] call() { final int numInstances = data.numInstances(); final int numOfClasses = data.numClasses(); final double[][] probabilityDistribution = new double[numOfClasses][numInstances]; for (int i = 0; i < numInstances; i++) { try { if (0 == i % 4000) counter.addAndGet(4000); double[] probs = classifier.distributionForInstance(data.instance(i)); for (int c = 0; c < numOfClasses; c++) probabilityDistribution[c][i] = probs[c]; } catch (Exception e) { IJ.showMessage("Could not apply Classifier!"); e.printStackTrace(); return null; } } return probabilityDistribution; } }; } /** * Toggle between overlay and original image with markings */ void toggleOverlay() { showColorOverlay = !showColorOverlay; //IJ.log("toggle overlay to: " + showColorOverlay); if (showColorOverlay) { ImageProcessor overlay = classifiedImage.getProcessor().duplicate(); //classifiedImage.show(); double shift = 255.0 / MAX_NUM_CLASSES; overlay.multiply(shift + 1); overlay = overlay.convertToByte(false); overlay.setColorModel(overlayLUT); ///new ImagePlus("Overlay", overlay).show(); resultOverlay.setImage(overlay); } else resultOverlay.setImage(null); displayImage.updateAndDraw(); //drawExamples(); } /** * Select a list and deselect the others * @param e item event 
(originated by a list) * @param i list index */ void listSelected(final ItemEvent e, final int i) { drawExamples(); displayImage.setColor(Color.YELLOW); for (int j = 0; j < numOfClasses; j++) { if (j == i) { final Roi newRoi = examples[i].get(exampleList[i].getSelectedIndex()); // Set selected trace as current ROI newRoi.setImage(displayImage); displayImage.setRoi(newRoi); } else exampleList[j].deselect(exampleList[j].getSelectedIndex()); } displayImage.updateAndDraw(); } /** * Delete one of the ROIs * * @param e action event */ void deleteSelected(final ActionEvent e) { for (int i = 0; i < numOfClasses; i++) if (e.getSource() == exampleList[i]) { //delete item from ROI int index = exampleList[i].getSelectedIndex(); // kill Roi from displayed image if (displayImage.getRoi().equals(examples[i].get(index))) displayImage.killRoi(); examples[i].remove(index); //delete item from list exampleList[i].remove(index); } drawExamples(); } /** * Display the whole image after classification */ void showClassificationImage() { ImagePlus resultImage = new ImagePlus("classification result", classifiedImage.getProcessor().convertToByte(true).duplicate()); resultImage.show(); } /** * Apply classifier to test data */ public void applyClassifierToTestData() { // array of files to process File[] imageFiles; // create a file chooser for the image files JFileChooser fileChooser = new JFileChooser( "/home/jan/workspace/mpi/yolk/data/downsampled/2010-04-02 histon"); fileChooser.setFileSelectionMode(JFileChooser.FILES_ONLY); fileChooser.setMultiSelectionEnabled(true); // get selected files or abort if no file has been selected int returnVal = fileChooser.showOpenDialog(null); if (returnVal == JFileChooser.APPROVE_OPTION) { imageFiles = fileChooser.getSelectedFiles(); } else { return; } boolean showResults = true; boolean storeResults = false; if (imageFiles.length >= 3) { int decision = JOptionPane.showConfirmDialog(null, "You decided to process three or more image files. 
Do you want the results to be stored on the disk instead of opening them in Fiji?", "Save results?", JOptionPane.YES_NO_OPTION); if (decision == JOptionPane.YES_OPTION) { showResults = false; storeResults = true; } } final int numProcessors = Runtime.getRuntime().availableProcessors(); IJ.log("Processing " + imageFiles.length + " image files in " + numProcessors + " threads...."); setButtonsEnabled(false); Thread[] threads = new Thread[numProcessors]; class ImageProcessingThread extends Thread { final int numThread; final int numProcessors; final int numFurtherThreads; final File[] imageFiles; final boolean storeResults; final boolean showResults; public ImageProcessingThread(int numThread, int numProcessors, int numFurtherThreads, File[] imageFiles, boolean storeResults, boolean showResults) { this.numThread = numThread; this.numProcessors = numProcessors; this.numFurtherThreads = numFurtherThreads; this.imageFiles = imageFiles; this.storeResults = storeResults; this.showResults = showResults; } public void run() { for (int i = numThread; i < imageFiles.length; i += numProcessors) { File file = imageFiles[i]; ImagePlus testImage = IJ.openImage(file.getPath()); IJ.log("Processing image " + file.getName() + " in thread " + numThread); ImagePlus segmentation = applyClassifierToTestImage(testImage, numFurtherThreads); if (showResults) { segmentation.show(); testImage.show(); } if (storeResults) { IJ.save(segmentation, file.getPath() + "seg.tif"); segmentation.close(); testImage.close(); } } } } final int numFurtherThreads = Math.max(1, (numProcessors - imageFiles.length) / imageFiles.length + 1); // start threads for (int i = 0; i < numProcessors; i++) { threads[i] = new ImageProcessingThread(i, numProcessors, numFurtherThreads, imageFiles, storeResults, showResults); threads[i].start(); } // join all threads for (Thread thread : threads) { try { thread.join(); } catch (InterruptedException e) { } } setButtonsEnabled(true); } /** * Apply classifier to a set of images 
and create a multi-channel * probability distribution image. */ public void createProbImgFromTestData() { // array of files to process File[] imageFiles; // create a file chooser for the image files JFileChooser fileChooser = new JFileChooser( "/home/jan/workspace/mpi/yolk/data/downsampled/2010-04-02 histon"); fileChooser.setFileSelectionMode(JFileChooser.FILES_ONLY); fileChooser.setMultiSelectionEnabled(true); // get selected files or abort if no file has been selected int returnVal = fileChooser.showOpenDialog(null); if (returnVal == JFileChooser.APPROVE_OPTION) { imageFiles = fileChooser.getSelectedFiles(); } else { return; } boolean showResults = true; boolean storeResults = false; if (imageFiles.length >= 3) { int decision = JOptionPane.showConfirmDialog(null, "You decided to process three or more image files. Do you want the results to be stored on the disk instead of opening them in Fiji?", "Save results?", JOptionPane.YES_NO_OPTION); if (decision == JOptionPane.YES_OPTION) { showResults = false; storeResults = true; } } final int numProcessors = Runtime.getRuntime().availableProcessors(); IJ.log("Processing " + imageFiles.length + " image files in " + numProcessors + " threads...."); setButtonsEnabled(false); Thread[] threads = new Thread[numProcessors]; class ImageProcessingThread extends Thread { final int numThread; final int numProcessors; final int numFurtherThreads; final File[] imageFiles; final boolean storeResults; final boolean showResults; public ImageProcessingThread(int numThread, int numProcessors, int numFurtherThreads, File[] imageFiles, boolean storeResults, boolean showResults) { this.numThread = numThread; this.numProcessors = numProcessors; this.numFurtherThreads = numFurtherThreads; this.imageFiles = imageFiles; this.storeResults = storeResults; this.showResults = showResults; } public void run() { for (int i = numThread; i < imageFiles.length; i += numProcessors) { File file = imageFiles[i]; ImagePlus testImage = 
IJ.openImage(file.getPath()); IJ.log("Processing image " + file.getName() + " in thread " + numThread); ImagePlus probImage = createProbImgFromTestData(testImage, numFurtherThreads); if (showResults) { probImage.show(); testImage.show(); } if (storeResults) { IJ.save(probImage, file.getPath() + "prob.tif"); probImage.close(); testImage.close(); } } } } final int numFurtherThreads = Math.max(1, (numProcessors - imageFiles.length) / imageFiles.length + 1); // start threads for (int i = 0; i < numProcessors; i++) { threads[i] = new ImageProcessingThread(i, numProcessors, numFurtherThreads, imageFiles, storeResults, showResults); threads[i].start(); } // join all threads for (Thread thread : threads) { try { thread.join(); } catch (InterruptedException e) { } } setButtonsEnabled(true); } /** * Apply current classifier to image * * @param testImage test image (2D single image or stack) * @return result image (classification) */ public ImagePlus applyClassifierToTestImage(final ImagePlus testImage, final int numThreads) { IJ.log("Processing slices of " + testImage.getTitle() + " in " + numThreads + " threads..."); // Set proper class names (skip empty list ones) ArrayList<String> classNames = new ArrayList<String>(); if (null == loadedClassNames) { for (int i = 0; i < numOfClasses; i++) if (examples[i].size() > 0) classNames.add(classLabels[i]); } else classNames = loadedClassNames; final ImagePlus[] classifiedSlices = new ImagePlus[testImage.getStackSize()]; class ApplyClassifierThread extends Thread { final int startSlice; final int numSlices; final int numFurtherThreads; final ArrayList<String> classNames; public ApplyClassifierThread(int startSlice, int numSlices, int numFurtherThreads, ArrayList<String> classNames) { this.startSlice = startSlice; this.numSlices = numSlices; this.numFurtherThreads = numFurtherThreads; this.classNames = classNames; } public void run() { for (int i = startSlice; i < startSlice + numSlices; i++) { final ImagePlus testSlice = new 
ImagePlus(testImage.getImageStack().getSliceLabel(i), testImage.getImageStack().getProcessor(i).convertToByte(true)); // Create feature stack for test image IJ.showStatus("Creating features..."); IJ.log("Creating features for slice " + i + "..."); final FeatureStack testImageFeatures = new FeatureStack(testSlice); // Use the same features as the current classifier testImageFeatures.setEnabledFeatures(featureStack.getEnabledFeatures()); testImageFeatures.updateFeaturesMT(); final Instances testData = testImageFeatures.createInstances(classNames); final ImagePlus testClassImage = applyClassifier(testData, testSlice.getWidth(), testSlice.getHeight(), numFurtherThreads); testClassImage.setTitle("classified_" + testSlice.getTitle()); testClassImage.setProcessor(testClassImage.getProcessor().convertToByte(true).duplicate()); classifiedSlices[i - 1] = testClassImage; } } } final int numFurtherThreads = Math.max(1, (numThreads - testImage.getStackSize()) / testImage.getStackSize() + 1); final ApplyClassifierThread[] threads = new ApplyClassifierThread[numThreads]; int numSlices = testImage.getStackSize() / numThreads; for (int i = 0; i < numThreads; i++) { int startSlice = i * numSlices + 1; // last thread takes all the remaining slices if (i == numThreads - 1) numSlices = testImage.getStackSize() - (numThreads - 1) * (testImage.getStackSize() / numThreads); IJ.log("Starting thread " + i + " processing " + numSlices + " slices, starting with " + startSlice); threads[i] = new ApplyClassifierThread(startSlice, numSlices, numFurtherThreads, classNames); threads[i].start(); } // create classified image final ImageStack classified = new ImageStack(testImage.getWidth(), testImage.getHeight()); // join threads for (Thread thread : threads) try { thread.join(); } catch (InterruptedException e) { e.printStackTrace(); } // assamble classified image for (int i = 0; i < testImage.getStackSize(); i++) classified.addSlice(classifiedSlices[i].getTitle(), 
classifiedSlices[i].getProcessor()); return new ImagePlus("Classification result", classified); } /** * Create multi-channel probability distribution image from image * * @param testImage test image (2D single image or stack) * @param numThreads number of threads to be used * @return result image (probability distribution) */ public ImagePlus createProbImgFromTestData(final ImagePlus testImage, final int numThreads) { IJ.log("Processing slices of " + testImage.getTitle() + " in " + numThreads + " threads..."); // Set proper class names (skip empty list ones) ArrayList<String> classNames = new ArrayList<String>(); if (null == loadedClassNames) { for (int i = 0; i < numOfClasses; i++) if (examples[i].size() > 0) classNames.add(classLabels[i]); } else classNames = loadedClassNames; final int numFurtherThreads = Math.max(1, (numThreads - testImage.getStackSize()) / testImage.getStackSize() + 1); final ImagePlus[] probSlices = new ImagePlus[testImage.getStackSize() * numOfClasses]; class ProbImageThread extends Thread { final int startSlice; final int numSlices; final int numFurtherThreads; final ArrayList<String> classNames; public ProbImageThread(int startSlice, int numSlices, int numFurtherThreads, ArrayList<String> classNames) { this.startSlice = startSlice; this.numSlices = numSlices; this.numFurtherThreads = numFurtherThreads; this.classNames = classNames; } public void run() { for (int i = startSlice; i < startSlice + numSlices; i++) { final ImagePlus testSlice = new ImagePlus(testImage.getImageStack().getSliceLabel(i), testImage.getImageStack().getProcessor(i).convertToByte(true)); // Create feature stack for test image IJ.showStatus("Creating features for test image..."); IJ.log("Creating features for test image " + i + "..."); final FeatureStack testImageFeatures = new FeatureStack(testSlice); // Use the same features as the current classifier testImageFeatures.setEnabledFeatures(featureStack.getEnabledFeatures()); testImageFeatures.updateFeaturesST(); final 
Instances testData = testImageFeatures.createInstances(classNames); testData.setClassIndex(testData.numAttributes() - 1); final ImagePlus[] testClassImages = getClassifierDistribution(testData, testSlice.getWidth(), testSlice.getHeight(), numFurtherThreads); for (int c = 0; c < numOfClasses; c++) probSlices[(i - 1) * numOfClasses + c] = testClassImages[c]; } } } final ProbImageThread[] threads = new ProbImageThread[numThreads]; int numSlices = testImage.getStackSize() / numThreads; for (int i = 0; i < numThreads; i++) { int startSlice = i * numSlices + 1; // last thread takes all the remaining slices if (i == numThreads - 1) numSlices = testImage.getStackSize() - (numThreads - 1) * (testImage.getStackSize() / numThreads); IJ.log("Starting thread " + i + " processing " + numSlices + " slices, starting with " + startSlice); threads[i] = new ProbImageThread(startSlice, numSlices, numFurtherThreads, classNames); threads[i].start(); } // create probability image final ImageStack probStack = new ImageStack(testImage.getWidth(), testImage.getHeight()); // join all threads for (Thread thread : threads) { try { thread.join(); } catch (InterruptedException e) { } } // assemble probability image for (int i = 0; i < testImage.getStackSize() * numOfClasses; i++) probStack.addSlice(probSlices[i].getTitle(), probSlices[i].getProcessor().convertToByte(true).duplicate()); ImagePlus probImage = new ImagePlus("Class probability image", probStack); probImage.setDimensions(numOfClasses, testImage.getNSlices(), testImage.getNFrames()); probImage.setOpenAsHyperStack(true); return probImage; } /** * Load previously saved data */ public void loadTrainingData() { OpenDialog od = new OpenDialog("Choose data file", ""); if (od.getFileName() == null) return; loadTrainingData(od.getDirectory() + od.getFileName()); } /** * Save training model into a file */ public void saveTrainingData() { boolean examplesEmpty = true; for (int i = 0; i < numOfClasses; i++) if (examples[i].size() > 0) { 
examplesEmpty = false; break; } if (examplesEmpty && loadedTrainingData == null) { IJ.showMessage("There is no data to save"); return; } if (featureStack.getSize() < 2) { setButtonsEnabled(false); featureStack.updateFeaturesMT(); setButtonsEnabled(true); } Instances data = createTrainingInstances(); data.setClassIndex(data.numAttributes() - 1); if (null != loadedTrainingData && null != data) { IJ.log("merging data"); for (int i = 0; i < loadedTrainingData.numInstances(); i++) { // IJ.log("" + i) data.add(loadedTrainingData.instance(i)); } IJ.log("Finished"); } else if (null == data) data = loadedTrainingData; SaveDialog sd = new SaveDialog("Choose save file", "data", ".arff"); if (sd.getFileName() == null) return; IJ.log("Writing training data: " + data.numInstances() + " instances..."); writeDataToARFF(data, sd.getDirectory() + sd.getFileName()); IJ.log("Wrote training data: " + sd.getDirectory() + sd.getFileName()); } /** * Add new class in the panel (up to MAX_NUM_CLASSES) */ private void addNewClass() { if (numOfClasses == MAX_NUM_CLASSES) { IJ.showMessage("Trainable Segmentation", "Sorry, maximum number of classes has been reached"); return; } //IJ.log("Adding new class..."); String inputName = JOptionPane.showInputDialog("Please input a new label name"); if (null == inputName) return; if (null == inputName || 0 == inputName.length()) { IJ.error("Invalid name for class"); return; } inputName = inputName.trim(); if (0 == inputName.toLowerCase().indexOf("add to ")) inputName = inputName.substring(7); // Add new name to the list of labels classLabels[numOfClasses] = inputName; // Add new class label and list win.addClass(); repaintWindow(); // Force whole data to be updated updateWholeData = true; } /** * Repaint whole window */ private void repaintWindow() { // Repaint window SwingUtilities.invokeLater(new Runnable() { public void run() { win.invalidate(); win.validate(); win.repaint(); } }); } /** * Show advanced settings dialog * * @return false when canceled 
*/
	public boolean showSettingsDialog()
	{
		GenericDialogPlus gd = new GenericDialogPlus("Segmentation settings");

		final boolean[] oldEnableFeatures = this.featureStack.getEnabledFeatures();

		gd.addMessage("Training features:");
		final int rows = (int) Math.round(FeatureStack.availableFeatures.length / 2.0);
		gd.addCheckboxGroup(rows, 2, FeatureStack.availableFeatures, oldEnableFeatures);

		// the features of loaded data are fixed by the file
		if (loadedTrainingData != null)
		{
			final Vector<Checkbox> v = gd.getCheckboxes();
			for (Checkbox c : v)
				c.setEnabled(false);
			gd.addMessage("WARNING: no features are selectable while using loaded data");
		}

		gd.addMessage("General options:");

		gd.addMessage("Fast Random Forest settings:");
		gd.addNumericField("Number of trees:", numOfTrees, 0);
		gd.addNumericField("Random features", randomFeatures, 0);

		gd.addMessage("Class names:");
		for (int i = 0; i < numOfClasses; i++)
			gd.addStringField("Class " + (i + 1), classLabels[i], 15);

		gd.addMessage("Advanced options:");
		gd.addButton("Save feature stack", new ButtonListener("Select location to save feature stack", featureStack));
		gd.addSlider("Result overlay opacity", 0, 100, overlayOpacity);
		gd.addHelp("http://fiji.sc/wiki/Trainable_Segmentation_Plugin");

		gd.showDialog();

		if (gd.wasCanceled())
			return false;

		// The values must be read back in the same order the fields were added
		final int numOfFeatures = FeatureStack.availableFeatures.length;

		final boolean[] newEnableFeatures = new boolean[numOfFeatures];

		boolean featuresChanged = false;

		// Read checked features and check if any of them changed
		for (int i = 0; i < numOfFeatures; i++)
		{
			newEnableFeatures[i] = gd.getNextBoolean();
			if (newEnableFeatures[i] != oldEnableFeatures[i])
				featuresChanged = true;
		}

		// Read fast random forest parameters and check if changed
		final int newNumTrees = (int) gd.getNextNumber();
		final int newRandomFeatures = (int) gd.getNextNumber();

		boolean classNameChanged = false;
		for (int i = 0; i < numOfClasses; i++)
		{
			String s = gd.getNextString();
			if (null != s)
			{
				s = s.trim();
				// accept names typed like the button label, e.g. "Add to foo"
				if (0 == s.toLowerCase().indexOf("add to "))
					s = s.substring(7);
			}
			// validate after normalizing, so blank names are rejected too
			if (null == s || 0 == s.length())
			{
				IJ.log("Invalid name for class " + (i + 1));
				continue;
			}

			if (!s.equals(classLabels[i]))
			{
				classLabels[i] = s;
				classNameChanged = true;
				addExampleButton[i].setText("Add to " + classLabels[i]);
			}
		}

		// Update result overlay alpha
		final int newOpacity = (int) gd.getNextNumber();
		if (newOpacity != overlayOpacity)
		{
			overlayOpacity = newOpacity;
			overlayAlpha = AlphaComposite.getInstance(AlphaComposite.SRC_OVER, overlayOpacity / 100f);
			resultOverlay.setComposite(overlayAlpha);

			if (showColorOverlay)
				displayImage.updateAndDraw();
		}

		// If there is a change in the class names,
		// the data set (instances) must be updated.
		if (classNameChanged)
		{
			updateWholeData = true;
			// Pack window to update buttons
			win.pack();
		}

		// Update random forest if necessary
		if (newNumTrees != numOfTrees || newRandomFeatures != randomFeatures)
			if (!updateClassifier(newNumTrees, newRandomFeatures))
				IJ.log("Invalid random forest settings (number of trees and random features must be at least 1), keeping the previous ones");

		// Update feature stack if necessary
		if (featuresChanged)
		{
			this.setButtonsEnabled(false);
			this.featureStack.setEnabledFeatures(newEnableFeatures);
			this.featureStack.updateFeaturesMT();
			this.setButtonsEnabled(true);
			// Force whole data to be updated
			updateWholeData = true;
		}

		return true;
	}

	/**
	 * Button listener class to handle the button action from the
	 * settings dialog: saves the given feature stack as a TIFF file.
	 */
	static class ButtonListener implements ActionListener
	{
		String title;
		TextField text;
		FeatureStack featureStack;

		public ButtonListener(String title, FeatureStack featureStack)
		{
			this.title = title;
			this.featureStack = featureStack;
		}

		public void actionPerformed(ActionEvent e)
		{
			if (featureStack.isEmpty())
			{
				IJ.error("Error", "The feature stack has not been initialized yet, please train first.");
				return;
			}

			SaveDialog sd = new SaveDialog(title, "feature-stack", ".tif");
			final String dir = sd.getDirectory();
			final String filename = sd.getFileName();

			// dialog canceled
			if (null == dir || null == filename)
				return;

			if (!this.featureStack.saveStackAsTiff(dir + filename))
			{
				IJ.error("Error", "Feature stack could not be saved");
				return;
			}

			IJ.log("Feature stack saved as " + dir + filename);
		}
	}

	/**
	 * Update fast random forest classifier with new values
	 *
	 * @param newNumTrees new number of trees
	 * @param newRandomFeatures new number of random features per tree
	 * @return false if error (either value smaller than 1; nothing is changed then)
	 */
	private boolean updateClassifier(int newNumTrees, int newRandomFeatures)
	{
		if (newNumTrees < 1 || newRandomFeatures < 1)
			return false;
		numOfTrees = newNumTrees;
		randomFeatures = newRandomFeatures;

		rf.setNumTrees(numOfTrees);
		rf.setNumFeatures(randomFeatures);

		return true;
	}

	///////////////////////////////////////////////////////////////////////////
	// Library style methods //////////////////////////////////////////////////
	///////////////////////////////////////////////////////////////////////////

	/**
	 * No GUI constructor
	 *
	 * @param trainingImage input image
	 */
	public Trainable_Segmentation(ImagePlus trainingImage)
	{
		// no GUI
		this.useGUI = false;

		this.trainingImage = trainingImage;

		for (int i = 0; i < numOfClasses; i++)
			examples[i] = new ArrayList<Roi>();

		// Initialization of Fast Random Forest classifier
		rf = new FastRandomForest();
		rf.setNumTrees(numOfTrees);
		// Breiman suggests sqrt(number of features) random features per node,
		// but a fixed randomFeatures value seems to work better here
		rf.setNumFeatures(randomFeatures);
		rf.setSeed(123);

		classifier = rf;

		// Initialize feature stack (no features yet)
		featureStack = new FeatureStack(trainingImage);
	}

	/**
	 * Load training data (no GUI)
	 *
	 * @param pathname complete path name of the training data file (.arff)
	 * @return false if error (the loaded data and class names are then reset)
	 */
	public boolean loadTrainingData(String pathname)
	{
		IJ.log("Loading data from " + pathname + "...");
		loadedTrainingData = readDataFromARFF(pathname);
		if (null == loadedTrainingData)
		{
			IJ.error("ERROR: could not read training data from " + pathname);
			loadedClassNames = null;
			return false;
		}

		// Check the features that were used in the loaded data
		Enumeration<Attribute> attributes = loadedTrainingData.enumerateAttributes();
		final int numFeatures = FeatureStack.availableFeatures.length;
		boolean[] usedFeatures = new boolean[numFeatures];
		while (attributes.hasMoreElements())
		{
			final Attribute a = attributes.nextElement();
			for (int i = 0; i < numFeatures; i++)
				if (a.name().startsWith(FeatureStack.availableFeatures[i]))
					usedFeatures[i] = true;
		}

		// Check if classes match
		Attribute classAttribute = loadedTrainingData.classAttribute();
		Enumeration<String> classValues = classAttribute.enumerateValues();

		// Update list of names of loaded classes
		loadedClassNames = new ArrayList<String>();

		int j = 0;
		while (classValues.hasMoreElements())
		{
			final String className = classValues.nextElement().trim();
			loadedClassNames.add(className);

			IJ.log("Read class name: " + className);
			// extra classes (j >= numOfClasses) are reported by the count check below
			if (j < numOfClasses && !className.equals(this.classLabels[j]))
			{
				String s = classLabels[0];
				for (int i = 1; i < numOfClasses; i++)
					s = s.concat(", " + classLabels[i]);
				IJ.error("ERROR: Loaded classes and current classes do not match!\nExpected: " + s);
				// other methods use a non-null loadedClassNames as "data loaded"
				loadedTrainingData = null;
				loadedClassNames = null;
				return false;
			}
			j++;
		}

		if (j != numOfClasses)
		{
			IJ.error("ERROR: Loaded number of classes and current number do not match!");
			loadedTrainingData = null;
			loadedClassNames = null;
			return false;
		}

		IJ.log("Loaded data: " + loadedTrainingData.numInstances() + " instances, " + loadedTrainingData.numAttributes() + " attributes.");

		boolean featuresChanged = false;
		final boolean[] oldEnableFeatures = this.featureStack.getEnabledFeatures();
		// Compare the features used by the loaded data with the enabled ones
		for (int i = 0; i < numFeatures; i++)
		{
			if (usedFeatures[i] != oldEnableFeatures[i])
				featuresChanged = true;
		}

		// Update feature stack if necessary
		if (featuresChanged)
		{
			this.setButtonsEnabled(false);
			this.featureStack.setEnabledFeatures(usedFeatures);
			this.featureStack.updateFeaturesMT();
			this.setButtonsEnabled(true);
			// Force whole data to be updated
			updateWholeData = true;
		}

		return true;
	}

	/**
	 * Get current classification result
	 * @return classified image (null until a classifier has been applied)
	 */
	public ImagePlus getClassifiedImage()
	{
		return classifiedImage;
	}
}