Java tutorial: OpenNLPMaxentModelTrainerTask

The class below, taken from the Mothsoft Alexis project, picks up a predictive model that is pending training, streams the matching documents out of Hibernate, turns them into training events, trains an OpenNLP maximum-entropy (GIS) model, and persists the result to disk as a ".bin.gz" file.

/*
 * Copyright 2012 Tim Garrett, Mothsoft LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.mothsoft.alexis.engine.predictive;

import java.io.File;
import java.io.IOException;
import java.sql.Timestamp;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.Map;

import javax.persistence.EntityManager;
import javax.persistence.PersistenceContext;

import opennlp.maxent.GIS;
import opennlp.maxent.GISModel;
import opennlp.maxent.io.GISModelWriter;
import opennlp.maxent.io.SuffixSensitiveGISModelWriter;
import opennlp.model.DataIndexer;
import opennlp.model.Event;
import opennlp.model.EventStream;
import opennlp.model.TwoPassDataIndexer;

import org.apache.commons.lang.time.StopWatch;
import org.apache.log4j.Logger;
import org.hibernate.ScrollableResults;
import org.hibernate.Session;
import org.hibernate.ejb.HibernateEntityManager;
import org.springframework.transaction.PlatformTransactionManager;
import org.springframework.transaction.TransactionDefinition;
import org.springframework.transaction.TransactionStatus;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.transaction.support.DefaultTransactionDefinition;
import org.springframework.transaction.support.TransactionCallback;
import org.springframework.transaction.support.TransactionCallbackWithoutResult;
import org.springframework.transaction.support.TransactionTemplate;

import com.mothsoft.alexis.dao.DataSetPointDao;
import com.mothsoft.alexis.dao.DocumentDao;
import com.mothsoft.alexis.dao.ModelDao;
import com.mothsoft.alexis.domain.DataSetPoint;
import com.mothsoft.alexis.domain.Document;
import com.mothsoft.alexis.domain.DocumentState;
import com.mothsoft.alexis.domain.Model;
import com.mothsoft.alexis.domain.ModelState;
import com.mothsoft.alexis.domain.SortOrder;
import com.mothsoft.alexis.domain.TimeUnits;
import com.mothsoft.alexis.engine.Task;

/**
 * Trains an OpenNLP maximum-entropy (GIS) model for a single {@link Model}:
 * finds a model in the PENDING state, turns the matching documents into
 * training events, runs GIS training, and persists the result as a ".bin.gz"
 * file under the configured base directory.
 */
public class OpenNLPMaxentModelTrainerTask extends AbstractModelTrainer implements ModelTrainer, Task {

    private static final Logger logger = Logger.getLogger(OpenNLPMaxentModelTrainerTask.class);

    private static final String OUTCOME_FORMAT = "+%d%s:%f";
    private static final String BIN_GZ_EXT = ".bin.gz";

    private DataSetPointDao dataSetPointDao;
    private DocumentDao documentDao;
    private ModelDao modelDao;
    private TransactionTemplate transactionTemplate;

    private File baseDirectory;
    private int iterations;
    private int cutoff;

    @PersistenceContext
    private EntityManager em;

    public OpenNLPMaxentModelTrainerTask() {
        super();
    }

    public void setBaseDirectory(File baseDirectory) {
        this.baseDirectory = baseDirectory;
    }

    public void setDataSetPointDao(DataSetPointDao dataSetPointDao) {
        this.dataSetPointDao = dataSetPointDao;
    }

    public void setDocumentDao(final DocumentDao documentDao) {
        this.documentDao = documentDao;
    }

    public void setModelDao(ModelDao modelDao) {
        this.modelDao = modelDao;
    }

    public void setTransactionManager(final PlatformTransactionManager transactionManager) {
        final TransactionDefinition transactionDefinition = new DefaultTransactionDefinition(
                DefaultTransactionDefinition.PROPAGATION_REQUIRES_NEW);
        this.transactionTemplate = new TransactionTemplate(transactionManager, transactionDefinition);
    }

    public void setIterations(int iterations) {
        this.iterations = iterations;
    }

    public void setCutoff(int cutoff) {
        this.cutoff = cutoff;
    }

    protected DocumentDao getDocumentDao() {
        return this.documentDao;
    }

    @Transactional
    @Override
    public void execute() {
        final Long modelId = findAndMark();

        if (modelId != null) {
            logger.info(String.format("Training model %d", modelId));
            final StopWatch stopWatch = new StopWatch();
            stopWatch.start();

            final Model model = this.modelDao.get(modelId);
            train(model);

            stopWatch.stop();
            logger.info(String.format("Training model %d took: %s", modelId, stopWatch.toString()));
        }
    }

    @Override
    public void train(Model model) {
        final int lookahead = model.getLookahead();
        final TimeUnits timeUnits = model.getTimeUnits();
        final long durationOfUnit = timeUnits.getDuration();

        // widen the window by one unit before the start and 'lookahead' units
        // after the end so every document has points to compare against
        final Date startDate = TimeUnits.floor(model.getStartDate(), timeUnits);
        final Date endDate = TimeUnits.ceil(model.getEndDate(), timeUnits);
        final Timestamp startDatePoints = new Timestamp(startDate.getTime() - durationOfUnit);
        final Timestamp endDatePoints = new Timestamp(endDate.getTime() + (lookahead * durationOfUnit));

        List<DataSetPoint> points = this.dataSetPointDao.findAndAggregatePointsGroupedByUnit(
                model.getTrainingDataSet(), startDatePoints, endDatePoints, timeUnits);
        Map<Date, DataSetPoint> pointMap = toMap(points);
        final Map<Date, Float> percentChangeMap = calculatePercentChange(points, pointMap);

        // should release these collections once percent change is calculated
        points = null;
        pointMap = null;

        final HibernateEntityManager hem = this.em.unwrap(HibernateEntityManager.class);
        final Session session = hem.getSession();

        final Long userId = model.getUserId();
        final DocumentState state = null;
        final String queryString = model.getTopic().getSearchExpression();
        final ScrollableResults scrollableResults = this.documentDao.scrollableSearch(userId, state, queryString,
                SortOrder.DATE_ASC, startDatePoints, endDatePoints);

        try {
            final EventStream eventStream = new DocumentScoreEventStream(model, scrollableResults, session,
                    percentChangeMap);
            final DataIndexer dataIndexer = new TwoPassDataIndexer(eventStream, this.cutoff);

            if (!logger.isDebugEnabled()) {
                GIS.PRINT_MESSAGES = false;
            }

            logger.debug("Invoking GIS.trainModel");
            final GISModel gisModel = GIS.trainModel(this.iterations, dataIndexer);
            logger.debug("GIS.trainModel complete");

            // because we've been clearing the entity manager's session
            model = this.modelDao.get(model.getId());

            writeModelToFile(model, gisModel);

            logger.info("Created model: " + gisModel);
            model.onTrainingComplete();
        } catch (final OutOfMemoryError e) {
            logError(model.getId(), e);
            throw e;
        } catch (final Exception e) {
            logError(model.getId(), e);
        } finally {
            scrollableResults.close();
        }
    }

    private void logError(final Long modelId, final Throwable t) {
        // mark the model as failed in a fresh transaction so the state change
        // survives the rollback of the training transaction
        this.transactionTemplate.execute(new TransactionCallbackWithoutResult() {
            @Override
            protected void doInTransactionWithoutResult(final TransactionStatus status) {
                logger.error("Model " + modelId + " training failed: " + t, t);
                final Model model = OpenNLPMaxentModelTrainerTask.this.modelDao.get(modelId);
                model.setState(ModelState.ERROR);
            }
        });
    }

    private void writeModelToFile(final Model model, final GISModel gisModel) throws IOException {
        // persist under <baseDirectory>/<userId>/<modelId>.bin.gz
        final File userFile = new File(this.baseDirectory, "" + model.getUserId());
        userFile.mkdirs();

        final File file = new File(userFile, model.getId() + BIN_GZ_EXT);
        final GISModelWriter writer = new SuffixSensitiveGISModelWriter(gisModel, file);

        try {
            logger.debug("Calling GISModelWriter.persist");
            writer.persist();
            logger.debug("GISModelWriter.persist complete");
        } finally {
            logger.debug("Calling GISModelWriter.close");
            writer.close();
            logger.debug("GISModelWriter closed");
        }
    }

    /**
     * Mark a model as training in a separate transaction to ensure external
     * visibility.
     *
     * @return the ID of the marked model, or null if no model was pending
     */
    private Long findAndMark() {
        return this.transactionTemplate.execute(new TransactionCallback<Long>() {
            @Override
            public Long doInTransaction(TransactionStatus txStatus) {
                return OpenNLPMaxentModelTrainerTask.this.modelDao.findAndMarkOne(ModelState.PENDING,
                        ModelState.TRAINING);
            }
        });
    }

    /**
     * Adapts the scrolled document results into maxent training events,
     * emitting one event per lookahead step for each document.
     */
    private class DocumentScoreEventStream implements EventStream {

        private static final int BATCH_SIZE = 25;

        private final Model model;
        private final Map<Date, Float> percentChangeMap;
        private final ScrollableResults scrollableResults;
        private final Session session;

        private int documentNumber = 0;
        private int pendingOutcomes = 0;
        private Document doc = null;
        private String[] context = new String[0];
        private float[] values = new float[0];

        public DocumentScoreEventStream(final Model model, final ScrollableResults scrollableResults,
                final Session session, final Map<Date, Float> percentChangeMap) {
            this.model = model;
            this.scrollableResults = scrollableResults;
            this.session = session;
            this.percentChangeMap = percentChangeMap;
        }

        @Override
        public Event next() throws IOException {
            final Event event;

            if (pendingOutcomes == 0) {
                pendingOutcomes = model.getLookahead();

                // [Document][Float]
                final Object[] object = this.scrollableResults.get();
                doc = (Document) object[0];
                documentNumber++;

                final Map<String, Integer> contextMap;

                // handle unusual case of stale index. Would be nice to fix...
                if (doc == null) {
                    logger.warn("Can't find document number " + documentNumber + "; index stale?");
                    contextMap = Collections.emptyMap();
                } else {
                    contextMap = OpenNLPMaxentContextBuilder.buildContext(doc);
                }

                context = new String[contextMap.size()];
                values = new float[contextMap.size()];
                OpenNLPMaxentContextBuilder.buildContextArrays(contextMap, context, values);
            }

            if (doc == null) {
                event = new Event("MISSING_DOCUMENT", context, values);
                // count down here as well so a missing document cannot stall the stream
                pendingOutcomes--;
            } else {
                final String outcome = buildOutcome(model, model.getLookahead() - pendingOutcomes,
                        doc.getCreationDate(), percentChangeMap);
                event = new Event(outcome, context, values);
                pendingOutcomes--;

                if (logger.isDebugEnabled()) {
                    logger.debug(String.format("Event for model: %d, document ID: %d, outcome: %s", model.getId(),
                            doc.getId(), outcome));
                }
            }

            // clear out every BATCH_SIZE documents
            if (pendingOutcomes == 0 && documentNumber % BATCH_SIZE == 0) {
                session.flush();
                session.clear();
            }

            return event;
        }

        private String buildOutcome(final Model model, final int i, final Date creationDate,
                final Map<Date, Float> percentChangeMap) {
            final TimeUnits timeUnits = model.getTimeUnits();
            final Date baseTime = TimeUnits.floor(creationDate, timeUnits);
            final Date time = new Date(baseTime.getTime() + (i * timeUnits.getDuration()));

            double pctChange;

            if (percentChangeMap.containsKey(time)) {
                pctChange = percentChangeMap.get(time);
            } else {
                pctChange = 0.0d;
            }

            final double absPctChange = Math.abs(pctChange);

            // need discrete values
            for (int j = 0; j < Model.OUTCOME_ARRAY.length; j++) {
                if (absPctChange > Model.OUTCOME_ARRAY[j] && absPctChange <= Model.OUTCOME_ARRAY[j + 1]) {
                    final double closest = Model.OUTCOME_ARRAY[j];
                    pctChange = pctChange < 0.0d ? -1 * closest : closest;
                    break;
                }
            }

            final String outcome = String.format(OUTCOME_FORMAT, i, timeUnits.name(), pctChange);
            return outcome;
        }

        @Override
        public boolean hasNext() throws IOException {
            return scrollableResults.next() || pendingOutcomes > 0;
        }
    }
}
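
To round out the tutorial, here is a minimal companion sketch (not part of the original project) of how a model persisted by writeModelToFile could be read back and evaluated with the same OpenNLP maxent API. The class name, the file path, and the example feature names and values are hypothetical; in the real project the context and value arrays come from OpenNLPMaxentContextBuilder.buildContext(doc).

import java.io.File;
import java.io.IOException;

import opennlp.maxent.io.SuffixSensitiveGISModelReader;
import opennlp.model.MaxentModel;

public class MaxentModelLoadExample {

    public static void main(final String[] args) throws IOException {
        // hypothetical path, mirroring writeModelToFile: <baseDirectory>/<userId>/<modelId>.bin.gz
        final File modelFile = new File("/var/lib/alexis/models/1/42.bin.gz");

        // the suffix-sensitive reader is the counterpart of the
        // SuffixSensitiveGISModelWriter used above: it selects the GZIP binary
        // format from the ".bin.gz" suffix
        final MaxentModel model = new SuffixSensitiveGISModelReader(modelFile).getModel();

        // hypothetical real-valued features standing in for the arrays the
        // trainer builds from OpenNLPMaxentContextBuilder
        final String[] context = new String[] { "term=market", "term=rally" };
        final float[] values = new float[] { 1.0f, 2.0f };

        // evaluate the context and report the most likely outcome label
        // (a string formatted per the trainer's OUTCOME_FORMAT)
        final double[] probabilities = model.eval(context, values);
        System.out.println("Best outcome: " + model.getBestOutcome(probabilities));
    }
}

In the tutorial class itself, the dependencies (the DAOs, the transaction manager, the base directory, and the iterations and cutoff settings) are supplied through setters, and training is driven by calling execute(), which claims one pending model per invocation.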