org.esa.snap.classification.gpf.BaseClassifier.java Source code

Java tutorial

Introduction

Here is the source code for org.esa.snap.classification.gpf.BaseClassifier.java

Source

/*
 * Copyright (C) 2016 by Array Systems Computing Inc. http://www.array.ca
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License as published by the Free
 * Software Foundation; either version 3 of the License, or (at your option)
 * any later version.
 * This program is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
 * more details.
 *
 * You should have received a copy of the GNU General Public License along
 * with this program; if not, see http://www.gnu.org/licenses/
 */
package org.esa.snap.classification.gpf;

import be.abeel.util.Pair;
import com.bc.ceres.core.ProgressMonitor;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import com.thoughtworks.xstream.XStream;
import net.sf.javaml.classification.Classifier;
import net.sf.javaml.core.Dataset;
import net.sf.javaml.core.DefaultDataset;
import net.sf.javaml.core.DenseInstance;
import net.sf.javaml.core.Instance;
import org.esa.snap.core.datamodel.Band;
import org.esa.snap.core.datamodel.IndexCoding;
import org.esa.snap.core.datamodel.Product;
import org.esa.snap.core.datamodel.ProductData;
import org.esa.snap.core.datamodel.ProductNodeGroup;
import org.esa.snap.core.datamodel.VectorDataNode;
import org.esa.snap.core.datamodel.VirtualBand;
import org.esa.snap.core.dataop.downloadable.StatusProgressMonitor;
import org.esa.snap.core.gpf.Operator;
import org.esa.snap.core.gpf.OperatorException;
import org.esa.snap.core.gpf.Tile;
import org.esa.snap.core.util.ProductUtils;
import org.esa.snap.core.util.SystemUtils;
import org.esa.snap.engine_utilities.gpf.OperatorUtils;
import org.esa.snap.engine_utilities.gpf.StackUtils;
import org.esa.snap.engine_utilities.gpf.ThreadManager;
import org.esa.snap.engine_utilities.gpf.TileIndex;
import org.esa.snap.engine_utilities.util.VectorUtils;
import org.opengis.referencing.crs.CoordinateReferenceSystem;

import java.awt.*;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Base class for classifiers.
 */
public abstract class BaseClassifier implements SupervisedClassifier {

    private final ClassifierParams params;
    private final ClassifierReport classifierReport;

    private Product maskProduct = null;
    private Product[] featureProducts;
    private int sourceImageWidth;
    private int sourceImageHeight;
    private boolean classifierTrained = false;

    private Band labelBand = null; // target

    // E.g., for Random Forest, this is the number of trees that vote for the label over the total number of trees
    private Band confidenceBand = null; // target

    private Band trainingSetMaskBand = null; // source
    private double maskNoDataValue = Double.NaN;
    private double maxClassValue = Double.NaN;

    private FeatureInfo[] featureInfoList;

    protected Classifier mlClassifier;

    private final static String LabelBandName = "LabeledClasses"; // target
    private final static String ConfidenceBandName = "Confidence"; // target
    public final static String VectorNodeNameLabelSource = "VectorNodeName";

    private boolean doLoadClassifier;

    private VectorDataNode[] polygonVectorDataNodes;
    private Map<VectorDataNode, Integer> polygonVectorDataNodeToVectorIndex;
    private Map<Integer, String> classLabelMap;
    private Map<String, Integer> labelClassMap;
    private boolean useVectorNodeNameAsLabel;

    private ClassifierDescriptor loadedClassifierDescriptor = null; // only for when doLoadClassifier is true

    private final static int INT_NO_DATA_VALUE = -1;
    private final static double DOUBLE_NO_DATA_VALUE = Double.NaN;
    private final static int NOT_IN_POLYGON = -1;

    public final static String CLASSIFIER_FILE_EXTENSION = ".class";
    public final static String CLASSIFIER_USER_INFO_FILE_EXTENSION = ".xml";
    public final static String CLASSIFIER_ROOT_FOLDER = "classifiers";

    private final static int minPowerSetSize = 1;
    private final static int maxPowerSetSize = 30;
    private double topClassifierPercent = 0;
    private String topClassifierName;
    private FeatureInfo[] topFeatureInfoList;

    private static final String[] excludedBands = new String[] { "lat_band", "long_band", "flags" };

    public static class ClassifierParams {
        private final String classifierType;
        private final String productSuffix;
        private final Product[] sourceProducts;
        private final int numTrainSamples;
        private double minClassValue;
        private double classValStepSize;
        private int classLevels;
        private String savedClassifierName;
        private boolean doClassValQuantization;
        private final boolean trainOnRaster;
        private final String[] trainingBands;
        private String[] trainingVectors;
        private String labelSource; // vector node name or attribute name
        private String[] featureBands;
        private final boolean evaluateClassifier;
        private final boolean evaluateFeaturePowerSet;

        public ClassifierParams(final String classifierType, final String productSuffix,
                final Product[] sourceProducts, final int numTrainSamples, final double minClassValue,
                final double classValStepSize, final int classLevels, final String savedClassifierName,
                final boolean doClassValQuantization, final boolean trainOnRaster, final String[] trainingBands,
                final String[] trainingVectors, final String[] featureBands, final String labelSource,
                final boolean evaluateClassifier, final boolean evaluateFeaturePowerSet) {
            this.classifierType = classifierType;
            this.productSuffix = productSuffix;
            this.sourceProducts = sourceProducts;
            this.numTrainSamples = numTrainSamples;
            this.minClassValue = minClassValue;
            this.classValStepSize = classValStepSize;
            this.classLevels = classLevels;
            this.savedClassifierName = savedClassifierName;
            this.doClassValQuantization = doClassValQuantization;
            this.trainOnRaster = trainOnRaster;
            this.trainingBands = trainingBands;
            this.trainingVectors = trainingVectors;
            this.featureBands = featureBands;
            this.labelSource = labelSource;
            this.evaluateClassifier = evaluateClassifier;
            this.evaluateFeaturePowerSet = evaluateFeaturePowerSet;
        }
    }

    public BaseClassifier(final ClassifierParams params) {
        this.params = params;
        this.classifierReport = new ClassifierReport(params.classifierType, params.savedClassifierName);
    }

    protected Object getObjectToSave(final Dataset trainDataset) {
        return mlClassifier;
    }

    protected Object getXMLInfoToSave(final BaseClassifier.ClassifierUserInfo commonInfo) {
        return commonInfo;
    }

    public String getClassifierType() {
        return params.classifierType;
    }

    public String getProductSuffix() {
        return params.productSuffix;
    }

    public String getClassifierName() {
        return params.savedClassifierName;
    }

    public Classifier getMLClassifier() {
        return mlClassifier;
    }

    public void initialize() throws OperatorException, IOException {

        checkSourceProductsValidity();

        // mask product is always the 1st product (same assumption in BaseOperatorUI)
        maskProduct = params.sourceProducts[0];

        // Check even if they are loaded from file.
        if (params.classValStepSize < 0.0) {
            throw new OperatorException("class value step size = " + params.classValStepSize);
        }
        if (params.classLevels < 2) {
            throw new OperatorException("class levels = " + params.classLevels + "; it must be at least 2");
        }
        maxClassValue = getMaxValue(params.minClassValue, params.classValStepSize, params.classLevels);

        if (!doLoadClassifier) {
            if (params.trainOnRaster && params.trainingBands == null) {
                trainingSetMaskBand = maskProduct.getBandAt(0);
            }
        }

        if (params.trainOnRaster && params.trainingBands != null && params.trainingBands.length == 1) {
            String bandName = params.trainingBands[0];
            if (params.trainingBands[0].contains("::")) {
                bandName = params.trainingBands[0].substring(0, params.trainingBands[0].indexOf("::"));
            }
            trainingSetMaskBand = maskProduct.getBand(bandName);
            if (trainingSetMaskBand == null) {
                throw new OperatorException("Fail to find training band in 1st source product: " + bandName);
            }
        }

        // TODO...
        //final int startIdx = (maskProduct == null) ? 0 : 1;
        //featureProducts = new Product[params.sourceProducts.length - startIdx];
        //System.arraycopy(params.sourceProducts, 0 + startIdx, featureProducts, 0, featureProducts.length);
        featureProducts = params.sourceProducts;

        if (trainingSetMaskBand != null && trainingSetMaskBand.isNoDataValueSet()) {
            maskNoDataValue = trainingSetMaskBand.getNoDataValue();
        }

        // polygonsAsClasses contains the names of all the polygons the user has selected to use as classes.
        // E.g., the user can create polygons named "water", "trees" and "shrubs"
        // There will be 3 classes named "water", "trees" and "shrubs".
        // All the pixels in the "water" polygon will have the class "water"
        // Get the corresponding VectorDataNode and store them in polygonVectorDataNodes.
        if (maskProduct != null && !params.trainOnRaster) {
            polygonVectorDataNodeToVectorIndex = new HashMap<>();

            if (params.trainingVectors == null || params.trainingVectors.length == 0) {
                final List<String> geometryNames = new ArrayList<>();
                final String[] nodeNames = maskProduct.getMaskGroup().getNodeNames();
                for (String name : nodeNames) {
                    geometryNames.add(name + "::" + maskProduct.getName());
                }
                if (geometryNames.size() < 2) {
                    throw new OperatorException(
                            "Cannot train on vectors because source product has less than 2 vectors");
                }
                params.trainingVectors = geometryNames.toArray(new String[geometryNames.size()]);
            }
            if (doLoadClassifier || params.trainingVectors != null) {
                if (params.trainingVectors.length == 1) {
                    throw new OperatorException("Please select two or more vectors as classes");
                }

                polygonVectorDataNodes = new VectorDataNode[params.trainingVectors.length];
                final ProductNodeGroup<VectorDataNode> vectorGroup = maskProduct.getVectorDataGroup();
                for (int i = 0; i < params.trainingVectors.length; ++i) {
                    int multiProductIndex = params.trainingVectors[i].indexOf("::");
                    String name = params.trainingVectors[i];
                    if (multiProductIndex > 0) {
                        name = params.trainingVectors[i].substring(0, multiProductIndex);
                    }

                    polygonVectorDataNodes[i] = vectorGroup.get(name);
                    if (polygonVectorDataNodes[i] == null) {
                        throw new OperatorException("Cannot find vector " + params.trainingVectors[i]);
                    }
                }

                useVectorNodeNameAsLabel = params.labelSource == null || params.labelSource.isEmpty()
                        || params.labelSource.equals(VectorNodeNameLabelSource);

                // The index is going to be the class value
                classLabelMap = new HashMap<>();
                labelClassMap = new HashMap<>();
                int classIndex = 0;
                final Set<String> attribValues = new HashSet<>();
                for (int i = 0; i < polygonVectorDataNodes.length; i++) {
                    polygonVectorDataNodeToVectorIndex.put(polygonVectorDataNodes[i], i);
                    if (useVectorNodeNameAsLabel) {
                        classIndex = i;
                        classLabelMap.put(classIndex, polygonVectorDataNodes[i].getName());
                    } else {
                        String classLabel = VectorUtils.getAttribStringValue(polygonVectorDataNodes[i],
                                params.labelSource);

                        if (!classLabelMap.values().contains(classLabel)) {
                            classLabelMap.put(classIndex, classLabel);
                            labelClassMap.put(classLabel, classIndex);
                            classIndex++;
                        }
                    }
                }
            }
        }

        //SystemUtils.LOG.info("doLoadClassifier = " + doLoadClassifier);
        //SystemUtils.LOG.info("doClassValQuantization = " + doClassValQuantization);
        //SystemUtils.LOG.info("Min class value = " + minClassValue + "; class value step size = " + classValStepSize
        //        + "; class levels = " + classLevels + "; max class value = " + maxClassValue);
    }

    public static double getMaxValue(final double minVal, final double stepSize, final int levels) {
        return minVal + stepSize * (levels - 1);
    }

    private void checkSourceProductsValidity() {

        // All the source products must have the same raster dimensions.
        // Here we are assuming that a band will have the same raster dimensions as the product it belongs to.

        sourceImageHeight = params.sourceProducts[0].getSceneRasterHeight();
        sourceImageWidth = params.sourceProducts[0].getSceneRasterWidth();

        for (int i = 1; i < params.sourceProducts.length; i++) {
            if (sourceImageHeight != params.sourceProducts[i].getSceneRasterHeight()
                    || sourceImageWidth != params.sourceProducts[i].getSceneRasterWidth()) {
                throw new OperatorException("Source products are of different dimensions");
            }
        }
    }

    /**
     * Create target product.
     */
    public Product createTargetProduct() {

        Product targetProduct = new Product(params.sourceProducts[0].getName() + getProductSuffix(),
                params.sourceProducts[0].getProductType(), sourceImageWidth, sourceImageHeight);

        ProductUtils.copyProductNodes(params.sourceProducts[0], targetProduct);

        final int dataType = (params.trainOnRaster) ? ProductData.TYPE_FLOAT32 : ProductData.TYPE_INT16;
        labelBand = new Band(LabelBandName, dataType, sourceImageWidth, sourceImageHeight);

        final String unit = (params.trainOnRaster && trainingSetMaskBand != null ? trainingSetMaskBand.getUnit()
                : "discrete classes");
        labelBand.setUnit(unit);
        final double noDataVal = params.trainOnRaster ? DOUBLE_NO_DATA_VALUE : INT_NO_DATA_VALUE;
        labelBand.setNoDataValue(noDataVal);
        labelBand.setNoDataValueUsed(true);
        labelBand.setValidPixelExpression(ConfidenceBandName + " >= 0.5");

        if (!params.trainOnRaster) {
            final IndexCoding indexCoding = new IndexCoding("Classes");
            indexCoding.addIndex("no data", INT_NO_DATA_VALUE, "no data");
            for (Integer i : classLabelMap.keySet()) {
                String label = classLabelMap.get(i);
                if (label == null || label.isEmpty())
                    label = "null";
                indexCoding.addIndex(label, i, "");
            }
            targetProduct.getIndexCodingGroup().add(indexCoding);
            labelBand.setSampleCoding(indexCoding);

            // remove training vectors
            final ProductNodeGroup<VectorDataNode> vectorDataGroup = targetProduct.getVectorDataGroup();
            for (String vector : params.trainingVectors) {
                vectorDataGroup.remove(vectorDataGroup.get(createClassLabel(vector)));
            }
        }
        targetProduct.addBand(labelBand);

        confidenceBand = new Band(ConfidenceBandName, ProductData.TYPE_FLOAT32, sourceImageWidth,
                sourceImageHeight);

        confidenceBand.setUnit("(0, 1]");
        confidenceBand.setNoDataValue(DOUBLE_NO_DATA_VALUE);
        confidenceBand.setNoDataValueUsed(true);
        targetProduct.addBand(confidenceBand);

        return targetProduct;
    }

    private static String createClassLabel(String vectorName) {
        String label = vectorName;
        if (vectorName.contains("::")) {
            label = vectorName.substring(0, vectorName.indexOf("::"));
        }
        return label;
    }

    public static boolean containsFeature(final Product product, final String[] featureNames) {
        final String[] bandnames = product.getBandNames();
        if (featureNames == null || bandnames == null) {
            return false;
        }
        for (String featureName : featureNames) {
            for (String bandname : bandnames) {
                if (bandname.contains(featureName)) {
                    return true;
                }
            }
        }
        return false;
    }

    protected double getConfidence(final Instance instance, final Object classVal) {
        final Map<Object, Double> classDis = mlClassifier.classDistribution(instance);
        return classDis.get(classVal);
    }

    public void computeTileStack(final Operator operator, final Map<Band, Tile> targetTileMap,
            final Rectangle targetRectangle, ProgressMonitor pm) throws OperatorException, IOException {

        final int x0 = targetRectangle.x;
        final int y0 = targetRectangle.y;
        final int xMax = x0 + targetRectangle.width;
        final int yMax = y0 + targetRectangle.height;
        //System.out.println("x0 = " + x0 + ", y0 = " + y0 + ", w = " + targetRectangle.width + ", h = " + targetRectangle.height);

        if (!classifierTrained) {
            if (doLoadClassifier) {
                loadClassifier(operator);
            } else {
                trainClassifier(operator, pm);
            }
        }

        final Tile labelTile = targetTileMap.get(labelBand);
        final Tile confidenceTile = targetTileMap.get(confidenceBand);
        final ProductData labelBuffer = labelTile.getDataBuffer();
        final ProductData confidenceBuffer = confidenceTile.getDataBuffer();
        final TileIndex tgtIndex = new TileIndex(labelTile);

        final Tile[] featureTiles = new Tile[featureInfoList.length];
        int i = 0;
        for (FeatureInfo feature : featureInfoList) {
            featureTiles[i++] = operator.getSourceTile(feature.featureBand, targetRectangle);
        }

        for (int y = y0; y < yMax; ++y) {
            tgtIndex.calculateStride(y);
            for (int x = x0; x < xMax; ++x) {
                final int tgtIdx = tgtIndex.getIndex(x);
                final double[] features = getFeatures(featureTiles, featureInfoList, x, y);
                if (features == null) {
                    labelBuffer.setElemDoubleAt(tgtIdx,
                            params.trainOnRaster ? DOUBLE_NO_DATA_VALUE : INT_NO_DATA_VALUE);
                    confidenceBuffer.setElemDoubleAt(tgtIdx, DOUBLE_NO_DATA_VALUE);
                    continue;
                }

                final Instance instance = new DenseInstance(features);

                double confidence = DOUBLE_NO_DATA_VALUE;
                Object classVal = mlClassifier.classify(instance);
                if (classVal == null) {
                    classVal = params.trainOnRaster ? DOUBLE_NO_DATA_VALUE : INT_NO_DATA_VALUE;
                } else {
                    confidence = getConfidence(instance, classVal);
                }
                labelBuffer.setElemDoubleAt(tgtIdx, (double) classVal);

                // classVal MUST be a key in classDis?
                //final double confidence = classDis.containsKey(classVal) ? classDis.get(classVal) : 0.0 ;
                confidenceBuffer.setElemDoubleAt(tgtIdx, confidence);
            }
        }
    }

    public static int getTotalNumBands(final Product[] products) {
        int numBands = 0;
        for (Product product : products) {
            numBands += product.getNumBands();
        }
        return numBands;
    }

    private void getVectorInstanceLists(final List<Instance> parentList, final List<Instance> trainList,
            final List<Instance> testList) {

        // This is only for the case where we get the samples a.k.a. instances from polygons.
        // parentList contains a balanced list of the instances. E.g., if there are 3 classes, namely 0, 1 and 2
        // and there are 300 Instances in parentList, there should be 100 of each class in parentList.
        // We want to put 50 of each class into trainList and testList.

        final HashMap<Integer, List<Instance>> classToInstanceListMap = new HashMap<>();
        for (Integer i : classLabelMap.keySet()) {
            classToInstanceListMap.put(i, new ArrayList<>());
        }

        for (Instance instance : parentList) {
            final int classVal = (int) ((double) instance.classValue());
            classToInstanceListMap.get(classVal).add(instance);
        }

        for (Integer i : classLabelMap.keySet()) {
            final List<Instance> list = classToInstanceListMap.get(i);
            final int listLen = list.size();
            //SystemUtils.LOG.info("classVal = " + i + " has " + listLen + " samples");

            // add every other to train or test list
            boolean addToTrainList = true;
            for (Instance instance : list) {
                if (addToTrainList) {
                    trainList.add(instance);
                    addToTrainList = false;
                } else {
                    testList.add(instance);
                    addToTrainList = true;
                }
            }

            // add first to train and then to test lists
            //trainList.addAll(list.subList(0, listLen / 2));
            //testList.addAll(list.subList(listLen / 2, listLen));
        }
    }

    public static boolean excludeBand(final String bandName) {
        for (String excludedBand : excludedBands) {
            if (bandName.startsWith(excludedBand)) {
                return true;
            }
        }
        return false;
    }

    private synchronized void trainClassifier(final Operator operator, final ProgressMonitor opPM)
            throws IOException {

        if (classifierTrained)
            return;
        try {
            if (params.featureBands == null) {
                final List<String> allFeatureBands = new ArrayList<>();
                for (Product p : params.sourceProducts) {
                    for (Band b : p.getBands()) {
                        if (b == trainingSetMaskBand)
                            continue;
                        String bandName = b.getName();
                        if (excludeBand(bandName)) {
                            continue;
                        }
                        allFeatureBands.add(bandName + "::" + p.getName());
                    }
                }
                params.featureBands = allFeatureBands.toArray(new String[allFeatureBands.size()]);
            }

            final Map<String, Product> productHashMap = new HashMap<>();
            for (Product product : params.sourceProducts) {
                productHashMap.put(product.getName(), product);
            }
            int i = 0;
            final List<FeatureInfo> featureInfos = new ArrayList<>(params.featureBands.length);
            for (String s : params.featureBands) {
                final int multiProductIndex = s.indexOf("::");
                String bandName = s;
                String productName = maskProduct.getName();
                if (multiProductIndex > 0) {
                    bandName = s.substring(0, multiProductIndex);
                    productName = s.substring(s.indexOf("::") + 2);
                }

                final Product product = productHashMap.get(productName);
                if (product != null) {
                    final Band featureBand = product.getBand(bandName);
                    if (featureBand == null) {
                        throw new OperatorException("Failed to find feature band " + s);
                    } else if (trainingSetMaskBand != null && featureBand == trainingSetMaskBand) {
                        throw new OperatorException("The training band has also been selected as a feature band");
                    }

                    FeatureInfo featureInfo = new FeatureInfo(featureBand, i);
                    featureInfos.add(featureInfo);

                    i++;
                } else {
                    throw new OperatorException("Failed to find feature product " + s);
                }
            }

            featureInfoList = featureInfos.toArray(new FeatureInfo[featureInfos.size()]);

            final LabeledInstances allLabeledInstances = getLabeledInstances(operator, params.numTrainSamples * 2,
                    featureInfoList);

            if (params.evaluateClassifier && params.evaluateFeaturePowerSet) {
                runFeaturePowerSet(operator, allLabeledInstances, featureInfoList, opPM);
            }

            if (!classifierTrained) {

                mlClassifier = createMLClassifier(featureInfoList);

                Dataset trainDataset = trainClassifier(mlClassifier, getClassifierName(), allLabeledInstances,
                        featureInfoList, false);

                saveClassifier(trainDataset);
            }

        } finally {
            classifierTrained = true;
        }
    }

    private Dataset trainClassifier(final Classifier classifier, final String name,
            final LabeledInstances labeledInstances, final FeatureInfo[] featureInfos, boolean quickEvaluation) {

        final List<Instance> trainList;
        final List<Instance> testList;
        if (params.trainOnRaster) {
            trainList = labeledInstances.instanceList.subList(0, labeledInstances.instanceList.size() / 2);
            testList = labeledInstances.instanceList.subList(labeledInstances.instanceList.size() / 2,
                    labeledInstances.instanceList.size());
        } else {
            trainList = new ArrayList<>();
            testList = new ArrayList<>();
            getVectorInstanceLists(labeledInstances.instanceList, trainList, testList);
        }

        final Dataset trainDataset = new DefaultDataset(trainList);

        buildClassifier(classifier, trainDataset);

        if (params.evaluateClassifier) {
            final Dataset testDataset = new DefaultDataset(testList);

            if (quickEvaluation) {
                runQuickEvaluation(classifier, name, labeledInstances, featureInfos, testDataset);
            } else {
                runEvaluation(classifier, labeledInstances, featureInfos, testDataset);
            }
        }

        return trainDataset;
    }

    protected void buildClassifier(final Classifier classifier, final Dataset trainDataset) {
        classifier.buildClassifier(trainDataset);
    }

    private void runEvaluation(final Classifier mlClassifier, final LabeledInstances labeledInstances,
            final FeatureInfo[] featureInfos, final Dataset testDataset) {
        final StatusProgressMonitor pm = new StatusProgressMonitor(StatusProgressMonitor.TYPE.SUBTASK);

        //Thread thread1 = new Thread(new Runnable() {
        //    @Override
        //    public void run(){
        try {
            final Evaluator evaluator = new Evaluator(mlClassifier, classifierReport);

            evaluator.evaluateClassifier(labeledInstances.labelMap, labeledInstances.instanceList, testDataset,
                    "Testing");

            evaluator.evaluateFeatures(featureInfos, testDataset, "Testing", pm);

            saveAndOpenReport(true);

        } finally {
            pm.done();
        }
        //    }
        // });
        // thread1.start();
    }

    public void saveAndOpenReport(boolean openReport) {
        try {
            classifierReport.writeReport();

            if (openReport) {
                classifierReport.openClassifierReport();
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    private void runQuickEvaluation(final Classifier mlClassifier, final String name,
            final LabeledInstances labeledInstances, final FeatureInfo[] featureInfos, final Dataset testDataset) {

        final Evaluator evaluator = new Evaluator(mlClassifier,
                new ClassifierReport(params.classifierType, "dummy"));

        Evaluator.Score score = evaluator.evaluateClassifier(labeledInstances.labelMap,
                labeledInstances.instanceList, testDataset, "Testing");

        StringBuilder featureBands = new StringBuilder();
        for (FeatureInfo featureInfo : featureInfos) {
            featureBands.append(featureInfo.featureBand.getName());
            featureBands.append(", ");
        }
        classifierReport.addPowerSetEvaluation(
                name + ": " + "cv " + f(score.crossValidationPercent * 100) + "% " + featureBands.toString());

        /*
        final StatusProgressMonitor pm = new StatusProgressMonitor(StatusProgressMonitor.TYPE.SUBTASK);
        evaluator.evaluateFeatures(featureInfos, testDataset, "Testing", pm);
            
        StringBuilder featureScoreStr = new StringBuilder();
        for (String key : score.featureScoreMap.keySet()) {
        featureScoreStr.append(String.format("%-20s", key));
        featureScoreStr.append(score.featureScoreMap.get(key) + '\n');
        }
        classifierReport.addPowerSetEvaluation(featureScoreStr.toString());
        */

        if (score.crossValidationPercent > topClassifierPercent) {
            updateTopSpot(score.crossValidationPercent, name, featureInfos);
        }
    }

    private static String f(double val) {
        return String.format("%-6.2f", val);
    }

    private synchronized void updateTopSpot(final double percentCorrect, final String name,
            final FeatureInfo[] featureInfos) {
        topClassifierPercent = percentCorrect;
        topClassifierName = name;
        topFeatureInfoList = featureInfos;
    }

    private void runFeaturePowerSet(final Operator operator, final LabeledInstances allLabeledInstances,
            final FeatureInfo[] completeFeatureInfoList, final ProgressMonitor opPM) {
        final StatusProgressMonitor pm = new StatusProgressMonitor(StatusProgressMonitor.TYPE.SUBTASK);

        try {
            // get the power set of all features
            Set<Set<FeatureInfo>> featurePowerSet = Sets
                    .powerSet(ImmutableSet.copyOf(Arrays.asList(completeFeatureInfoList)));
            List<Set<FeatureInfo>> featureSetList = new ArrayList<>();
            for (Set<FeatureInfo> featureSet : featurePowerSet) {
                if (featureSet.size() < minPowerSetSize || featureSet.size() > maxPowerSetSize)
                    continue;
                featureSetList.add(featureSet);
            }
            pm.beginTask("Evaluating feature power set", featureSetList.size());

            int cnt = 1;
            for (Set<FeatureInfo> featureSet : featureSetList) {
                if (opPM.isCanceled()) {
                    break;
                }

                final FeatureInfo[] featureInfos = featureSet.toArray(new FeatureInfo[featureSet.size()]);

                Classifier setClassifier = createMLClassifier(featureInfos);

                // create subset of labeledInstances
                // LabeledInstances subsetLabeledInstances = createSubsetLabeledInstances(featureInfos, allLabeledInstances);

                final LabeledInstances allLabeledInstances2 = getLabeledInstances(operator,
                        params.numTrainSamples * 2, featureInfoList);

                trainClassifier(setClassifier, getClassifierName() + '.' + cnt, allLabeledInstances2, featureInfos,
                        true);
                ++cnt;
                pm.worked(1);
            }

            classifierReport.setTopClassifier("TOP Classifier = " + topClassifierName + " at "
                    + String.format("%-6.2f", topClassifierPercent * 100) + '%');

            if (topFeatureInfoList != null) {
                featureInfoList = topFeatureInfoList;

                mlClassifier = createMLClassifier(featureInfoList);

                // create subset of labeledInstances
                //LabeledInstances subsetLabeledInstances = createSubsetLabeledInstances(featureInfoList, allLabeledInstances);

                final LabeledInstances allLabeledInstances2 = getLabeledInstances(operator,
                        params.numTrainSamples * 2, featureInfoList);

                Dataset trainDataset = trainClassifier(mlClassifier, getClassifierName(), allLabeledInstances2,
                        featureInfoList, false);

                saveClassifier(trainDataset);

                classifierTrained = true;
            }

        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            pm.done();
        }
    }

    private LabeledInstances createSubsetLabeledInstances(final FeatureInfo[] featureInfos,
            final LabeledInstances allLabeledInstances) {
        final List<Instance> instanceList = new ArrayList<>();

        final List<Integer> featureIndexList = new ArrayList<>();
        for (FeatureInfo fi : featureInfos) {
            String name = fi.featureBand.getName();
            int i = 0;
            for (FeatureInfo origFI : featureInfoList) {
                if (name.equals(origFI.featureBand.getName())) {
                    featureIndexList.add(i);
                    break;
                }
                ++i;
            }
        }

        for (Instance instance : allLabeledInstances.instanceList) {
            instance.keySet();

            Instance newInstance = new DenseInstance(featureIndexList.size());
            newInstance.setClassValue(instance.classValue());

            int i = 0;
            for (Integer index : featureIndexList) {
                newInstance.put(i++, instance.get(index));
            }

            instanceList.add(newInstance);
        }

        return new LabeledInstances(allLabeledInstances.labelMap, instanceList);
    }

    private Path getClassifierFilePath() throws IOException {
        final Path classifierDir = SystemUtils.getAuxDataPath().resolve(CLASSIFIER_ROOT_FOLDER)
                .resolve(params.classifierType);
        if (Files.notExists(classifierDir)) {
            Files.createDirectories(classifierDir);
        }
        return classifierDir.resolve(params.savedClassifierName + CLASSIFIER_FILE_EXTENSION);
    }

    public static void findBandInProducts(final Product[] products, final String bandName, final int[] indices) {
        // indices[0] indexes into featureProducts
        // indices[1] indexes into featureProducts[indices[0].getBandAt()
        indices[0] = -1;
        indices[1] = -1;
        for (int i = 0; i < products.length; i++) {
            for (int j = 0; j < products[i].getNumBands(); j++) {
                if (products[i].getBandAt(j).getName().contains(bandName)) {
                    indices[0] = i;
                    indices[1] = j;
                    return;
                }
            }
        }
    }

    private void loadClassifierDescriptor() {

        try {
            final Path filePath = getClassifierFilePath();

            final FileInputStream fis = new FileInputStream(filePath.toString());
            try (final ObjectInputStream in = new ObjectInputStream(fis)) {

                loadedClassifierDescriptor = (ClassifierDescriptor) in.readObject();

                final String cType = loadedClassifierDescriptor.getClassifierType();
                if (!cType.equals(params.classifierType)) {
                    throw new OperatorException("Loaded classifier is " + cType + " NOT " + params.classifierType);
                }

                params.doClassValQuantization = loadedClassifierDescriptor.getDoClassValQuantization();
                params.minClassValue = loadedClassifierDescriptor.getMinClassValue();
                params.classValStepSize = loadedClassifierDescriptor.getClassValStepSize();
                params.classLevels = loadedClassifierDescriptor.getClassLevels();
                params.trainingVectors = loadedClassifierDescriptor.getPolygonsAsClasses();
            }
        } catch (Exception ex) {
            throw new OperatorException("Failed to load classifier " + ex.getMessage());
        }
    }

    private synchronized void loadClassifier(final Operator operator) throws IOException {

        if (classifierTrained)
            return;

        try {
            loadClassifierDescriptor();

            final String[] featureNames = loadedClassifierDescriptor.getFeatureNames();

            final int totalAvailableFeatures = getTotalNumBands(featureProducts);
            if (featureNames.length > totalAvailableFeatures) {
                throw new OperatorException("classifier expects " + featureNames.length
                        + " features; source product(s) only have " + totalAvailableFeatures);
            }

            mlClassifier = retrieveMLClassifier(loadedClassifierDescriptor);

            final double[] featureMinValues = loadedClassifierDescriptor.getFeatureMinValues();
            final double[] featureMaxValues = loadedClassifierDescriptor.getFeatureMaxValues();

            SystemUtils.LOG.info("*** Loaded " + params.classifierType + " classifier (filename = "
                    + params.savedClassifierName + ") to predict " + loadedClassifierDescriptor.getClassName());

            int numFeatures = featureNames.length;
            final List<FeatureInfo> featureInfos = new ArrayList<>(featureNames.length);
            final Set<Pair<Integer, Integer>> indicesSet = new HashSet<>();
            for (int i = 0; i < numFeatures; i++) {
                final int[] indices = new int[2];
                findBandInProducts(featureProducts, featureNames[i], indices);
                if (indices[0] < 0) {
                    throw new OperatorException(
                            "Failed to find feature band " + featureNames[i] + " in source product");
                }
                final Pair<Integer, Integer> idxPair = new Pair(indices[0], indices[1]);
                if (indicesSet.contains(idxPair)) {
                    throw new OperatorException(featureProducts[indices[0]].getBandAt(indices[1]).getName()
                            + " for " + featureNames[i] + " has already appeared as an earlier feature");
                }
                indicesSet.add(idxPair);
                Band featureBand = featureProducts[indices[0]].getBandAt(indices[1]);
                double noDataValue = DOUBLE_NO_DATA_VALUE;
                if (featureBand.isNoDataValueSet()) {
                    noDataValue = featureBand.getNoDataValue();
                }

                double offset = featureMinValues[i];
                double scale = 1.0 / (featureMaxValues[i] - offset);

                featureInfos.add(new FeatureInfo(featureBand, i, noDataValue, offset, scale));
            }

            featureInfoList = featureInfos.toArray(new FeatureInfo[featureInfos.size()]);

        } catch (Exception ex) {
            throw new OperatorException("Error loading or using loaded classifier (" + ex.getMessage() + ')');
        }

        if (params.evaluateClassifier && trainingSetMaskBand != null) {
            final LabeledInstances labeledInstances = getLabeledInstances(operator, params.numTrainSamples,
                    featureInfoList);
            final Dataset testDataset = new DefaultDataset(labeledInstances.instanceList);

            runEvaluation(mlClassifier, labeledInstances, featureInfoList, testDataset);
        }

        classifierTrained = true;
    }

    private static String getFirstPartOfExpression(final String polygonName, final int polygonIdx) {
        return '\'' + polygonName + "' ? " + polygonIdx + " : ";
    }

    private static String getExpression(final VectorDataNode[] polygons,
            final Map<VectorDataNode, Integer> indexMap) {

        if (polygons == null || indexMap == null) {
            return null;
        }

        final VectorDataNode firstNode = polygons[0];
        String expression = getFirstPartOfExpression(firstNode.getName(), indexMap.get(firstNode)) + NOT_IN_POLYGON;

        for (int i = 1; i < polygons.length; i++) {
            final VectorDataNode nextNode = polygons[i];
            expression = getFirstPartOfExpression(nextNode.getName(), indexMap.get(nextNode)) + '(' + expression
                    + ')';
        }

        return expression;
    }

    private LabeledInstances getInstanceListFromPolygons(final Operator operator, final int numInstances,
            final FeatureInfo[] featureInfos) throws OperatorException {

        final Dimension tileSize = new Dimension(512, 512);
        final Rectangle[] tileRectangles = OperatorUtils.getAllTileRectangles(maskProduct, tileSize, 0);

        final StatusProgressMonitor status = new StatusProgressMonitor(StatusProgressMonitor.TYPE.SUBTASK);
        status.beginTask("Extracting data... ", tileRectangles.length);

        final List<Instance> instanceList = new ArrayList<>();
        final ThreadManager threadManager = new ThreadManager();

        final int numClasses = classLabelMap.size();
        final int[] instancesCnt = new int[numClasses];
        for (int i = 0; i < instancesCnt.length; i++) {
            instancesCnt[i] = 0;
        }
        final int maxCnt = (int) Math.ceil(numInstances / (double) numClasses);

        //SystemUtils.LOG.info("getInstanceListFromPolygons maxCnt = " + maxCnt + " numInstances = " + numInstances);
        //SystemUtils.LOG.info("getInstanceListFromPolygons #tile rectangles = " + tileRectangles.length);

        try {
            final CoordinateReferenceSystem srcCRS = params.sourceProducts[0].getSceneCRS();

            // Loop through each rectangle to see if it intersects with any of the class polygons, if it does, then
            // the intersecting pixel is added to instanceList
            for (int i = 0; i < tileRectangles.length; i++) {

                final Rectangle rectangle = tileRectangles[i];

                // Get the class polygons that intersect this rectangle
                final VectorDataNode[] polygons = VectorUtils.getPolygonsForOneRectangle(rectangle,
                        params.sourceProducts[0].getSceneGeoCoding(), polygonVectorDataNodes);
                if (polygons.length == 0) {
                    continue;
                }

                final String virtualBandName = "tmpVirtualBand_" + i;
                final String expression = getExpression(polygons, polygonVectorDataNodeToVectorIndex);
                // The virtual band will contain the class value
                final Band virtualBand = new VirtualBand(virtualBandName, ProductData.TYPE_INT16, sourceImageWidth,
                        sourceImageHeight, expression);
                maskProduct.addBand(virtualBand);

                //System.out.println(virtualBandName + ": " + expression);

                final Thread worker = new Thread() {

                    @Override
                    public void run() {
                        try {
                            final int x0 = rectangle.x, y0 = rectangle.y;
                            final int w = rectangle.width, h = rectangle.height;
                            final int xMax = x0 + w, yMax = y0 + h;

                            final Tile virtualBandTile = operator.getSourceTile(virtualBand, rectangle);
                            final ProductData virtualBandData = virtualBandTile.getDataBuffer();

                            final Tile[] featureTiles = new Tile[featureInfos.length];
                            final ProductData[] featureBuffers = new ProductData[featureInfos.length];

                            for (int j = 0; j < featureInfos.length; j++) {
                                featureTiles[j] = operator.getSourceTile(featureInfos[j].featureBand, rectangle);
                                featureBuffers[j] = featureTiles[j].getDataBuffer();
                            }

                            for (int y = y0; y < yMax; ++y) {
                                for (int x = x0; x < xMax; ++x) {

                                    int classVal = virtualBandData
                                            .getElemIntAt(virtualBandTile.getDataBufferIndex(x, y));
                                    if (classVal < 0) {
                                        // This pixel is not inside a class polygon
                                        continue;
                                    }

                                    // Get the features values for this pixel
                                    final double[] features = getFeatures(featureTiles, featureInfos, x, y);
                                    if (features == null) {
                                        continue;
                                    }

                                    final Instance instance = new DenseInstance(features);
                                    if (useVectorNodeNameAsLabel) {
                                        instance.setClassValue((double) classVal);
                                    } else {
                                        int vectorIndex = classVal;
                                        String val = VectorUtils.getAttribStringValue(
                                                polygonVectorDataNodes[vectorIndex], params.labelSource);
                                        classVal = labelClassMap.get(val);
                                        instance.setClassValue((double) classVal);
                                    }

                                    synchronized (instanceList) {
                                        if (instanceList.size() < numInstances) {
                                            if (instancesCnt[classVal] < maxCnt) {
                                                instanceList.add(instance);
                                                instancesCnt[classVal]++;
                                                if (instanceList.size() >= numInstances) {
                                                    return;
                                                }
                                            }
                                        } else {
                                            return;
                                        }
                                    }
                                }
                            }
                        } catch (Exception e) {
                            SystemUtils.LOG.severe("Error retrieving features from polygons " + e.getMessage());
                        }
                    }
                };

                threadManager.add(worker);
                status.worked(1);
            }

            threadManager.finish();

            for (int i = 0; i < tileRectangles.length; i++) {

                Band band = maskProduct.getBand("tmpVirtualBand_" + i);
                if (band != null) {
                    maskProduct.removeBand(band);
                }
            }

        } catch (Throwable e) {
            OperatorUtils.catchOperatorException(params.classifierType + " getTrainingData from polygons ", e);
        } finally {
            status.done();
        }

        final Map<Double, String> labelMap = new HashMap<>();
        int i = 0;
        for (Integer classIndex : classLabelMap.keySet()) {
            labelMap.put((double) classIndex, classLabelMap.get(classIndex));
        }
        return new LabeledInstances(labelMap, instanceList);
    }

    private LabeledInstances getInstanceListFromMaskProduct(final Operator operator, final int numInstances,
            final FeatureInfo[] featureInfos) throws OperatorException {

        final Dimension tileSize = new Dimension(20, 10);
        final Rectangle[] tileRectangles = OperatorUtils.getAllTileRectangles(maskProduct, tileSize, 0);
        final StatusProgressMonitor status = new StatusProgressMonitor(StatusProgressMonitor.TYPE.SUBTASK);
        status.beginTask("Getting training data... ", tileRectangles.length);

        final List<Instance> instanceList = new ArrayList<>();

        try {
            final ThreadManager threadManager = new ThreadManager();

            for (final Rectangle rectangle : tileRectangles) {
                final Thread worker = new Thread() {

                    final int xMin = rectangle.x;
                    final int xMax = rectangle.x + rectangle.width;
                    final int yMin = rectangle.y;
                    final int yMax = rectangle.y + rectangle.height;

                    final Tile maskTile = operator.getSourceTile(trainingSetMaskBand, rectangle);
                    final Tile[] featureTiles = new Tile[featureInfos.length];

                    @Override
                    public void run() {
                        int i = 0;
                        for (FeatureInfo featureInfo : featureInfos) {
                            featureTiles[i++] = operator.getSourceTile(featureInfo.featureBand, rectangle);
                        }

                        getData(xMin, xMax, yMin, yMax, maskTile, featureTiles, numInstances, maskNoDataValue,
                                instanceList);
                    }
                };

                threadManager.add(worker);

                status.worked(1);
            }

            threadManager.finish();

            //SystemUtils.LOG.info("instanceList.size = " + instanceList.size());
            /*for (int i = 0; i < 3; i++) {
            dumpInstance(instanceList.get(i));
            }*/

        } catch (Throwable e) {
            OperatorUtils.catchOperatorException(params.classifierType + " getTrainingData ", e);
        } finally {
            status.done();
        }

        final Map<Double, String> labelMap = new HashMap<>();
        labelMap.put(0.0, trainingSetMaskBand.getName());
        return new LabeledInstances(labelMap, instanceList);
    }

    private static class LabeledInstances {
        final Map<Double, String> labelMap;
        final List<Instance> instanceList;

        LabeledInstances(final Map<Double, String> labelMap, List<Instance> instancesList) {
            this.labelMap = labelMap;
            this.instanceList = instancesList;
        }
    }

    private LabeledInstances getLabeledInstances(final Operator operator, final int numInstances,
            final FeatureInfo[] featureInfos) throws OperatorException {
        if (params.trainOnRaster) {
            return getInstanceListFromMaskProduct(operator, numInstances, featureInfos);
        } else {
            return getInstanceListFromPolygons(operator, numInstances, featureInfos);
        }
    }

    private double quantize(double val) {
        if (!params.doClassValQuantization) {
            return val;
        }
        return VectorUtils.quantize(val, params.minClassValue, maxClassValue, params.classValStepSize);
    }

    private void getData(final int xMin, final int xMax, final int yMin, final int yMax, final Tile maskTile,
            final Tile[] featureTiles, final int maxSamples, final double maskNoDataValue,
            List<Instance> instanceList) {

        for (int y = yMin; y < yMax; ++y) {
            for (int x = xMin; x < xMax; ++x) {
                final double maskValue = maskTile.getDataBuffer()
                        .getElemDoubleAt(maskTile.getDataBufferIndex(x, y));
                if (Double.isNaN(maskValue) || maskValue == maskNoDataValue) {
                    continue;
                }
                final double[] features = getFeatures(featureTiles, featureInfoList, x, y);
                if (features == null) {
                    continue;
                }

                final Instance instance = new DenseInstance(features);
                instance.setClassValue(quantize(maskValue));
                synchronized (instanceList) {
                    if (instanceList.size() < maxSamples) {
                        instanceList.add(instance);
                        if (instanceList.size() >= maxSamples) {
                            return;
                        }
                    } else {
                        return;
                    }
                }
            }
        }
    }

    private static double[] getFeatures(final Tile[] featureTiles, FeatureInfo[] featureInfos, final int x,
            final int y) {
        final double[] features = new double[featureTiles.length];
        for (int i = 0; i < featureTiles.length; ++i) {
            double val = featureTiles[i].getDataBuffer().getElemDoubleAt(featureTiles[i].getDataBufferIndex(x, y));
            if (val == featureInfos[i].featureNoDataValue) {
                return null;
            }

            // scale the value to [0, 1]
            val = (val - featureInfos[i].featureOffsetValue) * featureInfos[i].featureScaleValue;
            if (val > 1.0) {
                val = 1.0;
            } else if (val < 0.0) {
                val = 0.0;
            }
            features[i] = val;
        }
        return features;
    }

    private void saveClassifier(final Dataset trainDataset) throws IOException {

        // First save the classifier and other info required by the software to use it.

        Object objectToSave = getObjectToSave(trainDataset);

        final String className = trainingSetMaskBand == null ? "???"
                : StackUtils.getBandNameWithoutDate(trainingSetMaskBand.getName());
        //SystemUtils.LOG.info("*** Save " + classifierType + " classifier (filename = " + savedClassifierName
        //                             + ") to predict " + className);

        final Object[] sortedObjects = ClassifierAttributeEvaluation.getSortedObjects(trainDataset.classes());
        final double[] sortedClassValues = new double[sortedObjects.length];
        for (int i = 0; i < sortedObjects.length; i++) {
            sortedClassValues[i] = (double) sortedObjects[i];
            //SystemUtils.LOG.info("********* sortedClassValues[" + i + "] = " + sortedClassValues[i]);
        }

        // Order is IMPORTANT
        final String[] featureNames = new String[featureInfoList.length];
        final double[] featureMinValues = new double[featureInfoList.length];
        final double[] featureMaxValues = new double[featureInfoList.length];
        for (int i = 0; i < featureNames.length; i++) {
            final Band featureBand = featureInfoList[i].featureBand;
            featureNames[i] = featureBand.getName();
            if (featureNames[i].contains(StackUtils.MST) || featureNames[i].contains(StackUtils.SLV))
                featureNames[i] = StackUtils.getBandNameWithoutDate(featureNames[i]);
            featureMinValues[i] = featureBand.getStx().getMinimum();
            featureMaxValues[i] = featureBand.getStx().getMaximum();
        }

        final String classUnit = labelBand.getUnit();
        ClassifierDescriptor classifierDescriptor = new ClassifierDescriptor(params.classifierType,
                params.savedClassifierName, objectToSave, sortedClassValues, className, classUnit, featureNames,
                featureMinValues, featureMaxValues, params.doClassValQuantization, params.minClassValue,
                params.classValStepSize, params.classLevels, params.trainingVectors);

        final Path filePath = getClassifierFilePath();

        FileOutputStream fos;
        ObjectOutputStream out;
        try {
            fos = new FileOutputStream(filePath.toString());
            out = new ObjectOutputStream(fos);
            out.writeObject(classifierDescriptor);
            out.close();
        } catch (Exception ex) {
            throw new OperatorException("Failed to save classifier " + ex.getMessage());
        }

        // Now save in an xml file what the user needs to know to prepare the source products

        ClassifierUserInfo classifierUserInfo = new ClassifierUserInfo(params.savedClassifierName,
                params.classifierType, className, params.numTrainSamples, sortedClassValues, featureInfoList.length,
                params.trainingBands, params.trainingVectors, featureNames,
                (params.doClassValQuantization ? params.minClassValue : 0.0),
                (params.doClassValQuantization ? params.classValStepSize : 0.0),
                (params.doClassValQuantization ? params.classLevels : -1),
                (params.doClassValQuantization ? maxClassValue : 0.0));

        Object xmlToSave = getXMLInfoToSave(classifierUserInfo);

        final XStream xstream = new XStream();
        xstream.processAnnotations(xmlToSave.getClass());
        final String xmlContent = xstream.toXML(xmlToSave);
        final File infoFile = filePath.getParent()
                .resolve(params.savedClassifierName + CLASSIFIER_USER_INFO_FILE_EXTENSION).toFile();
        FileWriter fileWriter = new FileWriter(infoFile);
        fileWriter.write(xmlContent);
        fileWriter.flush();
        fileWriter.close();
    }

    private static void dumpInstance(Instance instance) {
        SystemUtils.LOG.info(" Class value = " + instance.classValue());
        for (int i = 0; i < instance.noAttributes(); i++) {
            SystemUtils.LOG.info(" attr " + i + ": " + instance.value(i));
        }
    }

    public static class FeatureInfo implements Comparable<FeatureInfo> {
        public final Band featureBand;
        public double featureNoDataValue;
        public final double featureOffsetValue;
        public final double featureScaleValue;
        private final int id;

        public FeatureInfo(Band featureBand, int id) {
            this.featureBand = featureBand;
            this.id = id;

            featureNoDataValue = DOUBLE_NO_DATA_VALUE;
            if (featureBand.isNoDataValueSet()) {
                featureNoDataValue = featureBand.getNoDataValue();
            }
            featureOffsetValue = featureBand.getStx().getMinimum();
            featureScaleValue = 1.0 / (featureBand.getStx().getMaximum() - featureOffsetValue);
        }

        public FeatureInfo(Band featureBand, int id, double featureNoDataValue, double featureOffsetValue,
                double featureScaleValue) {
            this.featureBand = featureBand;
            this.id = id;
            this.featureNoDataValue = featureNoDataValue;
            this.featureOffsetValue = featureOffsetValue;
            this.featureScaleValue = featureScaleValue;
        }

        public int compareTo(FeatureInfo o) {
            return Integer.compare(id, o.id);
        }
    }

    public static class ClassifierUserInfo {
        private String classifierFilename;
        private String classifierType;
        private String className; // E.g., biomass or landcover classes
        private int numSamples;
        private double[] sortedClasses;
        private int numFeatures;
        private String[] trainingBands; // can be null
        private String[] trainingVectors; // can be null
        private String[] featureNames;

        // If quantization is not done, then classLevels is set to -1
        private double minClassValue;
        private double classValStepSize;
        private int classLevels;
        private double maxClassValue;

        public ClassifierUserInfo(final String classifierFilename, final String classifierType,
                final String className, final int numSamples, final double[] sortedClasses, final int numFeatures,
                final String[] trainingBands, final String[] trainingVectors, final String[] featureNames,
                final double minClassValue, final double classValStepSize, final int classLevels,
                final double maxClassValue) {
            this.classifierFilename = classifierFilename;
            this.classifierType = classifierType;
            this.className = className;
            this.numSamples = numSamples;
            this.sortedClasses = sortedClasses;
            this.numFeatures = numFeatures;
            this.trainingBands = trainingBands;
            this.trainingVectors = trainingVectors;
            this.featureNames = featureNames;
            this.minClassValue = minClassValue;
            this.classValStepSize = classValStepSize;
            this.classLevels = classLevels;
            this.maxClassValue = maxClassValue;
        }
    }
}