org.datacleaner.components.machinelearning.MLClassificationTrainingAnalyzer.java Source code

Java tutorial

Introduction

Here is the source code for org.datacleaner.components.machinelearning.MLClassificationTrainingAnalyzer.java

Source

/**
 * DataCleaner (community edition)
 * Copyright (C) 2014 Free Software Foundation, Inc.
 *
 * This copyrighted material is made available to anyone wishing to use, modify,
 * copy, or redistribute it subject to the terms and conditions of the GNU
 * Lesser General Public License, as published by the Free Software Foundation.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License
 * for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this distribution; if not, write to:
 * Free Software Foundation, Inc.
 * 51 Franklin Street, Fifth Floor
 * Boston, MA  02110-1301  USA
 */
package org.datacleaner.components.machinelearning;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

import org.apache.commons.lang.SerializationUtils;
import org.apache.metamodel.util.CollectionUtils;
import org.apache.metamodel.util.HasNameMapper;
import org.datacleaner.api.Categorized;
import org.datacleaner.api.Configured;
import org.datacleaner.api.Description;
import org.datacleaner.api.Initialize;
import org.datacleaner.api.InputColumn;
import org.datacleaner.api.InputRow;
import org.datacleaner.api.NumberProperty;
import org.datacleaner.components.machinelearning.api.MLClassificationRecord;
import org.datacleaner.components.machinelearning.api.MLClassificationTrainer;
import org.datacleaner.components.machinelearning.api.MLClassifier;
import org.datacleaner.components.machinelearning.api.MLFeatureModifier;
import org.datacleaner.components.machinelearning.api.MLFeatureModifierBuilder;
import org.datacleaner.components.machinelearning.api.MLFeatureModifierBuilderFactory;
import org.datacleaner.components.machinelearning.api.MLFeatureModifierType;
import org.datacleaner.components.machinelearning.api.MLTrainerCallback;
import org.datacleaner.components.machinelearning.api.MLTrainingConstraints;
import org.datacleaner.components.machinelearning.api.MLTrainingOptions;
import org.datacleaner.components.machinelearning.impl.MLClassificationRecordImpl;
import org.datacleaner.components.machinelearning.impl.MLFeatureModifierBuilderFactoryImpl;
import org.datacleaner.components.machinelearning.impl.MLFeatureUtils;
import org.datacleaner.result.Crosstab;
import org.datacleaner.util.Percentage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.io.Files;

/**
 * Base class for analyzers that train a {@link MLClassifier} from incoming rows.
 *
 * <p>Each row is converted into an {@link MLClassificationRecord} built from the configured
 * {@link #classification} column and the (inherited) feature columns. The record's values are fed to
 * one {@link MLFeatureModifierBuilder} per configured feature modifier type, and the record itself is
 * routed either to the training set or to the cross-validation set according to
 * {@link #crossValidationSampleRate}. When the result is requested, the feature modifiers are built, a
 * trainer supplied by the concrete subclass trains the classifier, the model is optionally serialized to
 * a file, and confusion matrices are computed for both record sets.
 */
@Categorized(MachineLearningCategory.class)
public abstract class MLClassificationTrainingAnalyzer extends MLTrainingAnalyzer<MLClassificationAnalyzerResult> {

    private static final Logger logger = LoggerFactory.getLogger(MLClassificationTrainingAnalyzer.class);
    private static final MLFeatureModifierBuilderFactory featureModifierBuilderFactory = new MLFeatureModifierBuilderFactoryImpl();

    /** Number of consecutive record numbers per sampling cycle; the cross-validation rate is a share of it. */
    private static final int SAMPLING_CYCLE = 100;

    @Configured
    InputColumn<?> classification;

    @Configured
    @Description("Determine how much (if any) of the records should be used for cross-validation.")
    @NumberProperty(negative = false)
    Percentage crossValidationSampleRate = new Percentage(10);

    // Sequence counter for accepted records; drives the training/cross-validation split.
    private AtomicInteger recordCounter;
    // Concurrent queues, presumably because run() may be invoked from several threads - confirm.
    private Collection<MLClassificationRecord> trainingRecords;
    private Collection<MLClassificationRecord> crossValidationRecords;
    // One builder per entry of the inherited featureModifierTypes, in the same order.
    private List<MLFeatureModifierBuilder> featureModifierBuilders;

    /**
     * Resets the per-run state and creates one feature modifier builder per configured feature
     * modifier type, all sharing the same {@link MLTrainingConstraints}.
     */
    @Initialize
    public void init() {
        recordCounter = new AtomicInteger();
        trainingRecords = new ConcurrentLinkedQueue<>();
        crossValidationRecords = new ConcurrentLinkedQueue<>();
        featureModifierBuilders = new ArrayList<>(featureModifierTypes.length);

        // An unset maximum is passed on as -1 (presumably "no limit" - see MLTrainingConstraints).
        final int maxFeatures = maxFeaturesGeneratedPerColumn == null ? -1 : maxFeaturesGeneratedPerColumn;
        final MLTrainingConstraints constraints = new MLTrainingConstraints(maxFeatures,
                includeUniqueValueFeatures);
        for (MLFeatureModifierType featureModifierType : featureModifierTypes) {
            final MLFeatureModifierBuilder featureModifierBuilder = featureModifierBuilderFactory
                    .create(featureModifierType, constraints);
            featureModifierBuilders.add(featureModifierBuilder);
        }
    }

    /**
     * Converts the row into a training record, feeds its values to the feature modifier builders and
     * routes the record to either the training set or the cross-validation set. Rows for which
     * {@link MLClassificationRecordImpl#forTraining} returns {@code null} are skipped entirely.
     *
     * @param row the incoming row
     * @param distinctCount not used by this analyzer
     */
    @Override
    public void run(InputRow row, int distinctCount) {
        final MLClassificationRecord record = MLClassificationRecordImpl.forTraining(row, classification,
                featureColumns);
        if (record == null) {
            return;
        }

        // Value i of the record is observed by builder i (presumably featureModifierTypes is aligned
        // with featureColumns - confirm against the parent's validation).
        final Object[] recordValues = record.getRecordValues();
        for (int i = 0; i < recordValues.length; i++) {
            featureModifierBuilders.get(i).addRecordValue(recordValues[i]);
        }

        final int recordNumber = recordCounter.incrementAndGet();
        if (isCrossValidationRecord(recordNumber)) {
            crossValidationRecords.add(record);
        } else {
            trainingRecords.add(record);
        }
    }

    /**
     * Decides whether a record belongs to the cross-validation set. Within every cycle of 100
     * consecutive record numbers, those whose position in the cycle is below the configured percentage
     * are selected, so exactly that many records out of every 100 are held out (0% holds out none,
     * 100% holds out all). Assumes {@link Percentage#getNominator()} is on a 0-100 scale.
     *
     * @param recordNumber the 1-based sequence number of the record
     * @return {@code true} if the record should be used for cross-validation
     */
    private boolean isCrossValidationRecord(int recordNumber) {
        // floorMod keeps the position non-negative even if the counter ever overflows.
        final int positionInCycle = Math.floorMod(recordNumber, SAMPLING_CYCLE);
        return positionInCycle < crossValidationSampleRate.getNominator();
    }

    /**
     * Builds the feature modifiers, trains the classifier on the training records, optionally saves the
     * serialized model, and evaluates the classifier on both record sets.
     *
     * @return the trained classifier together with confusion matrices for the training records and the
     *         cross-validation records
     * @throws UncheckedIOException if the model cannot be written to {@code saveModelToFile}
     */
    @Override
    public MLClassificationAnalyzerResult getResult() {
        final List<MLFeatureModifier> featureModifiers = featureModifierBuilders.stream()
                .map(MLFeatureModifierBuilder::build).collect(Collectors.toList());
        final List<String> columnNames = CollectionUtils.map(featureColumns, new HasNameMapper());
        final MLTrainingOptions options = new MLTrainingOptions(classification.getDataType(), columnNames,
                featureModifiers);

        final MLClassificationTrainer trainer = createTrainer(options);
        log("Training model starting. Records=" + trainingRecords.size() + ", Columns=" + columnNames.size()
                + ", Features=" + MLFeatureUtils.getFeatureCount(featureModifiers) + ".");
        final MLClassifier classifier = trainer.train(trainingRecords, featureModifiers, new MLTrainerCallback() {
            @Override
            public void epochDone(int epochNo, int expectedEpochs) {
                // Only report progress for trainers that iterate over more than one epoch.
                if (expectedEpochs > 1) {
                    log("Training progress: Epoch " + epochNo + " of " + expectedEpochs + " done.");
                }
            }
        });

        if (saveModelToFile != null) {
            saveModel(classifier);
        }

        log("Trained model. Creating evaluation matrices.");

        final Crosstab<Integer> trainedRecordsConfusionMatrix = createConfusionMatrixCrosstab(classifier,
                trainingRecords);
        final Crosstab<Integer> crossValidationConfusionMatrix = createConfusionMatrixCrosstab(classifier,
                crossValidationRecords);

        return new MLClassificationAnalyzerResult(classifier, trainedRecordsConfusionMatrix,
                crossValidationConfusionMatrix);
    }

    /**
     * Writes the Java-serialized classifier to {@code saveModelToFile}.
     *
     * @param classifier the trained classifier to persist
     * @throws UncheckedIOException if writing the file fails
     */
    private void saveModel(MLClassifier classifier) {
        logger.info("Saving model to file: {}", saveModelToFile);
        final byte[] bytes = SerializationUtils.serialize(classifier);
        try {
            Files.write(bytes, saveModelToFile);
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to save model to file: " + saveModelToFile, e);
        }
    }

    /**
     * Creates the trainer used by {@link #getResult()} to train the classifier.
     *
     * @param options training options built from the classification column's data type, the feature
     *        column names and the built feature modifiers
     * @return the trainer to use
     */
    protected abstract MLClassificationTrainer createTrainer(MLTrainingOptions options);

    /**
     * Builds a confusion matrix by appending every record to an {@link MLConfusionMatrixBuilder} bound
     * to the given classifier.
     *
     * @param classifier the classifier to evaluate
     * @param records the records to evaluate it on
     * @return the resulting confusion matrix crosstab
     */
    private static Crosstab<Integer> createConfusionMatrixCrosstab(MLClassifier classifier,
            Collection<MLClassificationRecord> records) {
        final MLConfusionMatrixBuilder builder = new MLConfusionMatrixBuilder(classifier);
        for (MLClassificationRecord record : records) {
            builder.append(record);
        }
        return builder.build();
    }
}