Java tutorial: executing OpenNLP maxent models with Spring transactions

The class below, OpenNLPMaxentModelExecutorTask from the Mothsoft Alexis project, finds trained maxent models in the READY state, evaluates each against the current time period's documents, and writes the predicted points into a prediction data set. The full source follows, with a short standalone example at the end.
/*
 * Copyright 2012 Tim Garrett, Mothsoft LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.mothsoft.alexis.engine.predictive;

import java.io.File;
import java.sql.Timestamp;
import java.text.DateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import opennlp.maxent.io.SuffixSensitiveGISModelReader;
import opennlp.model.AbstractModel;
import opennlp.model.MaxentModel;

import org.apache.commons.lang.time.StopWatch;
import org.apache.log4j.Logger;
import org.hibernate.ScrollableResults;
import org.springframework.transaction.PlatformTransactionManager;
import org.springframework.transaction.TransactionDefinition;
import org.springframework.transaction.TransactionStatus;
import org.springframework.transaction.support.DefaultTransactionDefinition;
import org.springframework.transaction.support.TransactionCallback;
import org.springframework.transaction.support.TransactionTemplate;

import com.mothsoft.alexis.dao.DataSetPointDao;
import com.mothsoft.alexis.dao.DocumentDao;
import com.mothsoft.alexis.dao.ModelDao;
import com.mothsoft.alexis.domain.DataSetPoint;
import com.mothsoft.alexis.domain.Document;
import com.mothsoft.alexis.domain.Model;
import com.mothsoft.alexis.domain.ModelState;
import com.mothsoft.alexis.domain.ModelType;
import com.mothsoft.alexis.domain.SortOrder;
import com.mothsoft.alexis.domain.TimeUnits;
import com.mothsoft.alexis.engine.Task;

public class OpenNLPMaxentModelExecutorTask implements Task {

    private static final Logger logger = Logger.getLogger(OpenNLPMaxentModelExecutorTask.class);

    private static final Pattern OUTCOME_PATTERN = Pattern.compile("\\+(\\d+)(\\S+)\\:(\\S+)");

    private static final String BIN_GZ_EXT = ".bin.gz";

    private File baseDirectory;
    private DataSetPointDao dataSetPointDao;
    private DocumentDao documentDao;
    private ModelDao modelDao;
    private TransactionTemplate transactionTemplate;

    public OpenNLPMaxentModelExecutorTask() {
        super();
    }

    public void setBaseDirectory(File baseDirectory) {
        this.baseDirectory = baseDirectory;
    }

    public void setDataSetPointDao(final DataSetPointDao dataSetPointDao) {
        this.dataSetPointDao = dataSetPointDao;
    }

    public void setDocumentDao(DocumentDao documentDao) {
        this.documentDao = documentDao;
    }

    public void setModelDao(final ModelDao modelDao) {
        this.modelDao = modelDao;
    }

    public void setTransactionManager(final PlatformTransactionManager transactionManager) {
        final TransactionDefinition transactionDefinition = new DefaultTransactionDefinition(
                DefaultTransactionDefinition.PROPAGATION_REQUIRES_NEW);
        this.transactionTemplate = new TransactionTemplate(transactionManager, transactionDefinition);
    }
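    // Note on the transaction strategy: the template above is built with
    // PROPAGATION_REQUIRES_NEW, so findModelsToExecute() and each per-model
    // execute(modelId) call below run in their own short transactions rather
    // than in one long transaction spanning the entire batch.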
    @Override
    public void execute() {
        final StopWatch stopWatch = new StopWatch();
        stopWatch.start();

        final List<Long> modelIds = findModelsToExecute();
        final int size = modelIds.size();
        logger.info(String.format("Found %d models in state READY", size));

        int executed = 0;
        for (final Long modelId : modelIds) {
            boolean success = execute(modelId);
            if (success) {
                executed++;
            }
        }

        stopWatch.stop();
        logger.info(String.format("Executed %d of %d models, took: %s", executed, size, stopWatch.toString()));
    }

    private List<Long> findModelsToExecute() {
        return this.transactionTemplate.execute(new TransactionCallback<List<Long>>() {
            @Override
            public List<Long> doInTransaction(TransactionStatus arg0) {
                final List<Model> models = OpenNLPMaxentModelExecutorTask.this.modelDao
                        .findByTypeAndState(ModelType.MAXENT, ModelState.READY);
                final List<Long> modelIds = new ArrayList<Long>(models.size());
                for (final Model model : models) {
                    modelIds.add(model.getId());
                }
                return modelIds;
            }
        });
    }

    private boolean execute(final Long modelId) {
        return this.transactionTemplate.execute(new TransactionCallback<Boolean>() {
            @Override
            public Boolean doInTransaction(TransactionStatus arg0) {
                final Model model = OpenNLPMaxentModelExecutorTask.this.modelDao.get(modelId);
                return OpenNLPMaxentModelExecutorTask.this.doExecute(model);
            }
        });
    }
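    // doExecute walks the whole prediction pipeline for one model: load the
    // serialized GIS model from disk, evaluate it over this period's documents,
    // average the predicted percent change for each future period, then create,
    // update, or remove points in the prediction data set accordingly.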
    private boolean doExecute(final Model model) {
        final StopWatch stopWatch = new StopWatch();
        stopWatch.start();

        boolean result = false;

        try {
            logger.info(String.format("Executing model %d", model.getId()));

            // load model file
            final File userDirectory = new File(baseDirectory, "" + model.getUserId());
            final File modelFile = new File(userDirectory, model.getId() + BIN_GZ_EXT);
            final AbstractModel maxentModel = new SuffixSensitiveGISModelReader(modelFile).getModel();

            final Date now = new Date();
            final TimeUnits timeUnits = model.getTimeUnits();
            final Timestamp topOfPeriod = new Timestamp(TimeUnits.floor(now, timeUnits).getTime());
            final Timestamp endOfPeriod = new Timestamp(topOfPeriod.getTime() + timeUnits.getDuration() - 1);

            // first position: sum of changes predicted, second position: number
            // of samples--will calculate a boring old mean...
            final double[][] changeByPeriod = new double[model.getLookahead()][2];

            // initialize
            for (int i = 0; i < changeByPeriod.length; i++) {
                changeByPeriod[i][0] = 0.0d;
                changeByPeriod[i][1] = 0.0d;
            }

            // find the most recent point value
            // FIXME - some sparse data sets may require executing the model on
            // all documents since that point or applying some sort of
            // dead-reckoning logic for smoothing
            final DataSetPoint initial = this.dataSetPointDao.findLastPointBefore(model.getTrainingDataSet(),
                    endOfPeriod);

            // let's get the corner cases out of the way
            if (initial == null) {
                logger.warn("Insufficient data to execute model!");
                return false;
            }

            // happy path
            // build consolidated context of events in this period
            // find current value of training data set for this period
            final double[] probs = eval(model, topOfPeriod, endOfPeriod, maxentModel);

            // predict from the last available point, adjusted for time
            // remaining in period
            final double y0 = initial.getY();

            // map outcomes to periods in the future (at least no earlier than
            // this period)
            for (int i = 0; i < probs.length; i++) {
                // in the form +nU:+/-x, where n is the number of periods, U is
                // the unit type for the period, +/- is the direction, and x is
                // a discrete value from Model.OUTCOME_ARRAY
                final String outcome = maxentModel.getOutcome(i);
                final Matcher matcher = OUTCOME_PATTERN.matcher(outcome);

                if (!matcher.matches()) {
                    logger.warn("Can't process outcome: " + outcome + "; skipping");
                    continue;
                }

                final int period = Integer.valueOf(matcher.group(1));
                final String units = matcher.group(2);
                final double percentChange = Double.valueOf(matcher.group(3));

                // record the observation and the count of observations
                changeByPeriod[period][0] += percentChange;
                changeByPeriod[period][1] += 1.0d;

                if (logger.isDebugEnabled()) {
                    final double yi = y0 * (1 + percentChange);
                    logger.debug(String.format("Outcome: %s, %s: +%d, change: %f, new value: %f, probability: %f",
                            outcome, units, period, percentChange, yi, probs[i]));
                }
            }
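            // Worked example with made-up numbers: if two outcomes for period 0
            // predict +0.05 and -0.01, the mean change is +0.02; if the initial
            // point also sits 6 hours into a 24-hour period, the adjustment
            // factor below is 1 - 6/24 = 0.75, so the applied change is +0.015.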
            // build points for predictive data set
            double yn = y0;

            // we need to track the points and remove any that were not
            // predicted by this execution of the model
            final Timestamp endOfPredictionRange = new Timestamp(
                    topOfPeriod.getTime() + (changeByPeriod.length * timeUnits.getDuration()));
            final List<DataSetPoint> existingPoints = this.dataSetPointDao
                    .findByTimeRange(model.getPredictionDataSet(), topOfPeriod, endOfPredictionRange);

            for (int period = 0; period < changeByPeriod.length; period++) {
                final double totalPercentChange = changeByPeriod[period][0];
                final double sampleCount = changeByPeriod[period][1];
                double percentChange;

                if (totalPercentChange == 0.0d || sampleCount == 0.0d) {
                    percentChange = 0.0d;
                } else {
                    percentChange = totalPercentChange / sampleCount;
                }

                // apply adjustments only if the initial point is within the
                // time period, and only for the first time period
                boolean applyAdjustment = period == 0 && topOfPeriod.before(initial.getX());

                if (applyAdjustment) {
                    final double adjustmentFactor = findAdjustmentFactor(initial.getX(), timeUnits);
                    percentChange = (totalPercentChange / sampleCount) * adjustmentFactor;
                }

                // figure out the next value and coerce to a sane number of
                // decimal places (2)
                final double newValue = (double) Math.round(yn * (1.0d + percentChange) * 100) / 100;

                final Timestamp timestamp = new Timestamp(
                        topOfPeriod.getTime() + (period * timeUnits.getDuration()));

                if (logger.isDebugEnabled()) {
                    logger.debug(String.format("Model %d for data set %d predicted point: (%s, %f)", model.getId(),
                            model.getTrainingDataSet().getId(), DateFormat.getInstance().format(timestamp), newValue));
                }

                DataSetPoint ithPoint = this.dataSetPointDao.findByTimestamp(model.getPredictionDataSet(), timestamp);

                // conditionally create
                if (ithPoint == null) {
                    ithPoint = new DataSetPoint(model.getPredictionDataSet(), timestamp, newValue);
                    this.dataSetPointDao.add(ithPoint);
                } else {
                    // or update
                    ithPoint.setY(newValue);

                    // updated points retained, other existing removed
                    existingPoints.remove(ithPoint);
                }

                // store current and use as starting point for next iteration
                yn = newValue;
            }

            // remove stale points from an old model execution
            for (final DataSetPoint toRemove : existingPoints) {
                this.dataSetPointDao.remove(toRemove);
            }

            result = true;
        } catch (final Exception e) {
            logger.warn("Model " + model.getId() + " failed with: " + e, e);
            result = false;
        } finally {
            stopWatch.stop();
            logger.info(String.format("Executing model %d took %s", model.getId(), stopWatch.toString()));
        }

        return result;
    }

    /**
     * Returns 1 minus the fraction of the time period already completed, so the
     * predicted percent change can be applied uniformly over the time period.
     */
    private double findAdjustmentFactor(final Date date, final TimeUnits timeUnits) {
        final Date floor = TimeUnits.floor(date, timeUnits);
        final double dividend = (double) (date.getTime() - floor.getTime());
        final double divisor = (double) timeUnits.getDuration();
        final double percentTimeComplete = dividend / divisor;
        return 1.0d - percentTimeComplete;
    }

    private double[] eval(final Model model, final Timestamp topOfPeriod, final Timestamp endOfPeriod,
            final MaxentModel maxentModel) {
        final ScrollableResults scrollableResults = this.documentDao.scrollableSearch(model.getUserId(), null,
                model.getTopic().getSearchExpression(), SortOrder.DATE_ASC, topOfPeriod, endOfPeriod);

        // initialize with an estimated size to prevent a lot of resizing
        final Map<String, Integer> contextMap = new LinkedHashMap<String, Integer>(64 * 1024);

        try {
            while (scrollableResults.next()) {
                final Object[] row = scrollableResults.get();
                final Document document = (Document) row[0];

                if (document == null) {
                    // caused by stale index
                    continue;
                } else {
                    OpenNLPMaxentContextBuilder.append(contextMap, document);
                }
            }
        } finally {
            scrollableResults.close();
        }

        final String[] context = new String[contextMap.size()];
        final float[] values = new float[contextMap.size()];

        // copy map to arrays
        OpenNLPMaxentContextBuilder.buildContextArrays(contextMap, context, values);

        // eval
        return maxentModel.eval(context, values);
    }
}
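To make the moving parts above concrete, here is a minimal, self-contained sketch of the two bits of arithmetic the task relies on: parsing maxent outcome labels with the same OUTCOME_PATTERN regex, and scaling a prediction by the fraction of the period remaining, as findAdjustmentFactor does. The outcome strings and the 24-hour period duration are made-up illustrative values; the real unit names and durations come from the project's TimeUnits enum, which is not shown here.

import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Standalone demo of outcome parsing and adjustment-factor math; the outcome
// labels and period length below are hypothetical, chosen only for illustration.
public class OutcomeParsingDemo {

    // same pattern as OUTCOME_PATTERN above: +<periods><unit>:<signed change>
    private static final Pattern OUTCOME_PATTERN = Pattern.compile("\\+(\\d+)(\\S+)\\:(\\S+)");

    public static void main(String[] args) {
        // hypothetical model outcomes: 1 period ahead +5%, 2 periods ahead -1%,
        // plus one malformed label to show the skip path
        final String[] outcomes = { "+1DAY:+0.05", "+2DAY:-0.01", "garbage" };

        for (final String outcome : outcomes) {
            final Matcher matcher = OUTCOME_PATTERN.matcher(outcome);
            if (!matcher.matches()) {
                System.out.println("Can't process outcome: " + outcome + "; skipping");
                continue;
            }
            final int period = Integer.valueOf(matcher.group(1));
            final String units = matcher.group(2);
            final double percentChange = Double.valueOf(matcher.group(3));
            System.out.printf("outcome %s -> period +%d %s, change %f%n", outcome, period, units, percentChange);
        }

        // same idea as findAdjustmentFactor: if the last known point sits 6 hours
        // into a 24-hour period, only 75% of the period remains, so 75% of the
        // predicted change is applied (durations in milliseconds)
        final long periodDuration = 24L * 60L * 60L * 1000L;
        final long elapsed = 6L * 60L * 60L * 1000L;
        final double adjustmentFactor = 1.0d - ((double) elapsed / (double) periodDuration);
        System.out.printf("adjustment factor: %.2f%n", adjustmentFactor); // prints 0.75
    }
}

Run as-is, this prints the parsed (period, unit, change) triples, skips the malformed label, and reports an adjustment factor of 0.75, which matches how doExecute() damps the first period's prediction when the last observation falls partway through the period.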