com.mothsoft.alexis.engine.predictive.OpenNLPMaxentModelTrainerTask.java Source code

Java tutorial

Introduction

Here is the source code for com.mothsoft.alexis.engine.predictive.OpenNLPMaxentModelTrainerTask.java

Source

/*   Copyright 2012 Tim Garrett, Mothsoft LLC
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */
package com.mothsoft.alexis.engine.predictive;

import java.io.File;
import java.io.IOException;
import java.sql.Timestamp;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.Map;

import javax.persistence.EntityManager;
import javax.persistence.PersistenceContext;

import opennlp.maxent.GIS;
import opennlp.maxent.GISModel;
import opennlp.maxent.io.GISModelWriter;
import opennlp.maxent.io.SuffixSensitiveGISModelWriter;
import opennlp.model.DataIndexer;
import opennlp.model.Event;
import opennlp.model.EventStream;
import opennlp.model.TwoPassDataIndexer;

import org.apache.commons.lang.time.StopWatch;
import org.apache.log4j.Logger;
import org.hibernate.ScrollableResults;
import org.hibernate.Session;
import org.hibernate.ejb.HibernateEntityManager;
import org.springframework.transaction.PlatformTransactionManager;
import org.springframework.transaction.TransactionDefinition;
import org.springframework.transaction.TransactionStatus;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.transaction.support.DefaultTransactionDefinition;
import org.springframework.transaction.support.TransactionCallback;
import org.springframework.transaction.support.TransactionCallbackWithoutResult;
import org.springframework.transaction.support.TransactionTemplate;

import com.mothsoft.alexis.dao.DataSetPointDao;
import com.mothsoft.alexis.dao.DocumentDao;
import com.mothsoft.alexis.dao.ModelDao;
import com.mothsoft.alexis.domain.DataSetPoint;
import com.mothsoft.alexis.domain.Document;
import com.mothsoft.alexis.domain.DocumentState;
import com.mothsoft.alexis.domain.Model;
import com.mothsoft.alexis.domain.ModelState;
import com.mothsoft.alexis.domain.SortOrder;
import com.mothsoft.alexis.domain.TimeUnits;
import com.mothsoft.alexis.engine.Task;

public class OpenNLPMaxentModelTrainerTask extends AbstractModelTrainer implements ModelTrainer, Task {

    private static final Logger logger = Logger.getLogger(OpenNLPMaxentModelTrainerTask.class);

    private static final String OUTCOME_FORMAT = "+%d%s:%f";
    private static final String BIN_GZ_EXT = ".bin.gz";

    private DataSetPointDao dataSetPointDao;
    private DocumentDao documentDao;
    private ModelDao modelDao;
    private TransactionTemplate transactionTemplate;
    private File baseDirectory;
    private int iterations;
    private int cutoff;

    @PersistenceContext
    private EntityManager em;

    public OpenNLPMaxentModelTrainerTask() {
        super();
    }

    public void setBaseDirectory(File baseDirectory) {
        this.baseDirectory = baseDirectory;
    }

    public void setDataSetPointDao(DataSetPointDao dataSetPointDao) {
        this.dataSetPointDao = dataSetPointDao;
    }

    public void setDocumentDao(final DocumentDao documentDao) {
        this.documentDao = documentDao;
    }

    public void setModelDao(ModelDao modelDao) {
        this.modelDao = modelDao;
    }

    public void setTransactionManager(final PlatformTransactionManager transactionManager) {
        final TransactionDefinition transactionDefinition = new DefaultTransactionDefinition(
                DefaultTransactionDefinition.PROPAGATION_REQUIRES_NEW);
        this.transactionTemplate = new TransactionTemplate(transactionManager, transactionDefinition);
    }

    public void setIterations(int iterations) {
        this.iterations = iterations;
    }

    public void setCutoff(int cutoff) {
        this.cutoff = cutoff;
    }

    protected DocumentDao getDocumentDao() {
        return this.documentDao;
    }

    @Transactional
    @Override
    public void execute() {
        final Long modelId = findAndMark();

        if (modelId != null) {
            logger.info(String.format("Training model %d", modelId));

            final StopWatch stopWatch = new StopWatch();
            stopWatch.start();
            final Model model = this.modelDao.get(modelId);
            train(model);
            stopWatch.stop();

            logger.info(String.format("Training model %d took: %s", modelId, stopWatch.toString()));
        }
    }

    @Override
    public void train(Model model) {
        final int lookahead = model.getLookahead();
        final TimeUnits timeUnits = model.getTimeUnits();
        final long durationOfUnit = timeUnits.getDuration();

        final Date startDate = TimeUnits.floor(model.getStartDate(), timeUnits);
        final Date endDate = TimeUnits.ceil(model.getEndDate(), timeUnits);

        final Timestamp startDatePoints = new Timestamp(startDate.getTime() - durationOfUnit);
        final Timestamp endDatePoints = new Timestamp(endDate.getTime() + (lookahead * durationOfUnit));

        List<DataSetPoint> points = this.dataSetPointDao.findAndAggregatePointsGroupedByUnit(
                model.getTrainingDataSet(), startDatePoints, endDatePoints, timeUnits);
        Map<Date, DataSetPoint> pointMap = toMap(points);

        final Map<Date, Float> percentChangeMap = calculatePercentChange(points, pointMap);

        // should release these collections once percent change is calculated
        points = null;
        pointMap = null;

        final HibernateEntityManager hem = this.em.unwrap(HibernateEntityManager.class);
        final Session session = hem.getSession();

        final Long userId = model.getUserId();
        final DocumentState state = null;
        final String queryString = model.getTopic().getSearchExpression();
        final ScrollableResults scrollableResults = this.documentDao.scrollableSearch(userId, state, queryString,
                SortOrder.DATE_ASC, startDatePoints, endDatePoints);

        try {
            final EventStream eventStream = new DocumentScoreEventStream(model, scrollableResults, session,
                    percentChangeMap);
            final DataIndexer dataIndexer = new TwoPassDataIndexer(eventStream, this.cutoff);

            if (!logger.isDebugEnabled()) {
                GIS.PRINT_MESSAGES = false;
            }

            logger.debug("Invoking GIS.trainModel");
            final GISModel gisModel = GIS.trainModel(this.iterations, dataIndexer);
            logger.debug("GIS.trainModel complete");

            // because we've been clearing the entity manager's session
            model = this.modelDao.get(model.getId());
            writeModelToFile(model, gisModel);
            logger.info("Created model: " + gisModel);
            model.onTrainingComplete();

        } catch (final OutOfMemoryError e) {
            logError(model.getId(), e);
            throw e;
        } catch (final Exception e) {
            logError(model.getId(), e);
        } finally {
            scrollableResults.close();
        }
    }

    private void logError(final Long modelId, final Throwable t) {
        this.transactionTemplate.execute(new TransactionCallbackWithoutResult() {

            @Override
            protected void doInTransactionWithoutResult(final TransactionStatus status) {
                logger.error("Model " + modelId + " training failed: " + t, t);
                final Model model = OpenNLPMaxentModelTrainerTask.this.modelDao.get(modelId);
                model.setState(ModelState.ERROR);
            }
        });
    }

    private void writeModelToFile(final Model model, final GISModel gisModel) throws IOException {
        final File userFile = new File(this.baseDirectory, "" + model.getUserId());
        userFile.mkdirs();
        final File file = new File(userFile, model.getId() + BIN_GZ_EXT);
        final GISModelWriter writer = new SuffixSensitiveGISModelWriter(gisModel, file);

        try {
            logger.debug("Calling GISModelWriter.persist");
            writer.persist();
            logger.debug("GISModelWriter.persist complete");
        } finally {
            logger.debug("Calling GISModelWriter.close");
            writer.close();
            logger.debug("GISModelWriter closed");
        }
    }

    /**
     * Mark model as training in a separate transaction to ensure external
     * visibility
     * 
     * @return - id of marked model
     */
    private Long findAndMark() {
        return this.transactionTemplate.execute(new TransactionCallback<Long>() {
            @Override
            public Long doInTransaction(TransactionStatus txStatus) {
                return OpenNLPMaxentModelTrainerTask.this.modelDao.findAndMarkOne(ModelState.PENDING,
                        ModelState.TRAINING);
            }
        });
    }

    private class DocumentScoreEventStream implements EventStream {

        private static final int BATCH_SIZE = 25;

        private final Model model;
        private final Map<Date, Float> percentChangeMap;

        private final ScrollableResults scrollableResults;
        private final Session session;

        private int documentNumber = 0;
        private int pendingOutcomes = 0;
        private Document doc = null;
        private String[] context = new String[0];
        private float[] values = new float[0];

        public DocumentScoreEventStream(final Model model, final ScrollableResults scrollableResults,
                final Session session, final Map<Date, Float> percentChangeMap) {
            this.model = model;
            this.scrollableResults = scrollableResults;
            this.session = session;
            this.percentChangeMap = percentChangeMap;
        }

        @Override
        public Event next() throws IOException {
            final Event event;

            if (pendingOutcomes == 0) {
                pendingOutcomes = model.getLookahead();

                // [Document][Float]
                final Object[] object = this.scrollableResults.get();
                doc = (Document) object[0];
                documentNumber++;

                final Map<String, Integer> contextMap;

                // handle unusual case of stale index. Would be nice to fix...
                if (doc == null) {
                    logger.warn("Can't find document number " + documentNumber + "; index stale?");
                    contextMap = Collections.emptyMap();
                } else {
                    contextMap = OpenNLPMaxentContextBuilder.buildContext(doc);
                }

                context = new String[contextMap.size()];
                values = new float[contextMap.size()];
                OpenNLPMaxentContextBuilder.buildContextArrays(contextMap, context, values);
            }

            if (doc == null) {
                event = new Event("MISSING_DOCUMENT", context, values);
            } else {
                final String outcome = buildOutcome(model, model.getLookahead() - pendingOutcomes,
                        doc.getCreationDate(), percentChangeMap);
                event = new Event(outcome, context, values);
                pendingOutcomes--;

                if (logger.isDebugEnabled()) {
                    logger.debug(String.format("Event for model: %d, document ID: %d, outcome: %s", model.getId(),
                            doc.getId(), outcome));
                }
            }

            // clear out every BATCH_SIZE documents
            if (pendingOutcomes == 0 && documentNumber % BATCH_SIZE == 0) {
                session.flush();
                session.clear();
            }

            return event;
        }

        private String buildOutcome(final Model model, final int i, final Date creationDate,
                final Map<Date, Float> percentChangeMap) {

            final TimeUnits timeUnits = model.getTimeUnits();
            final Date baseTime = TimeUnits.floor(creationDate, timeUnits);
            final Date time = new Date(baseTime.getTime() + (i * timeUnits.getDuration()));

            double pctChange;

            if (percentChangeMap.containsKey(time)) {
                pctChange = percentChangeMap.get(time);
            } else {
                pctChange = 0.0d;
            }

            final double absPctChange = Math.abs(pctChange);

            // need discrete values
            for (int j = 0; j < Model.OUTCOME_ARRAY.length; j++) {
                if (absPctChange > Model.OUTCOME_ARRAY[j] && absPctChange <= Model.OUTCOME_ARRAY[j + 1]) {
                    final double closest = Model.OUTCOME_ARRAY[j];
                    pctChange = pctChange < 0.0d ? -1 * closest : closest;
                    break;
                }
            }

            final String outcome = String.format(OUTCOME_FORMAT, i, timeUnits.name(), pctChange);

            return outcome;
        }

        @Override
        public boolean hasNext() throws IOException {
            return scrollableResults.next() || pendingOutcomes > 0;
        }
    }
}