/*
 *   exist-stanford-corenlp: XQuery module to integrate the Stanford CoreNLP
 *   annotation pipeline library with eXist-db.
 *   Copyright (C) 2016 ljo
 *
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.exist.xquery.corenlp;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.IOException;
import java.io.StringReader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Properties;

import edu.stanford.nlp.ie.AbstractSequenceClassifier;
import edu.stanford.nlp.ie.crf.CRFClassifier;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.process.PTBTokenizer;
import edu.stanford.nlp.process.WordToSentenceProcessor;
import edu.stanford.nlp.sequences.SeqClassifierFlags;
import edu.stanford.nlp.sequences.ColumnDocumentReaderAndWriter;

import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.LogManager;

import org.apache.poi.POITextExtractor;
import org.apache.poi.extractor.ExtractorFactory;
import org.apache.poi.openxml4j.exceptions.InvalidFormatException;
import org.apache.poi.openxml4j.exceptions.OpenXML4JException;
import org.apache.poi.ss.usermodel.Cell;
import org.apache.poi.ss.usermodel.CellStyle;
import org.apache.poi.ss.usermodel.CreationHelper;
import org.apache.poi.ss.usermodel.Font;
import org.apache.poi.ss.usermodel.Row;
import org.apache.poi.ss.usermodel.Workbook;
import org.apache.poi.hssf.usermodel.HSSFWorkbook;
import org.apache.poi.xssf.usermodel.XSSFWorkbook;
import org.apache.xmlbeans.XmlException;

import org.exist.dom.persistent.BinaryDocument;
import org.exist.dom.persistent.DocumentImpl;
import org.exist.dom.QName;
import org.exist.dom.memtree.DocumentBuilderReceiver;
import org.exist.dom.memtree.MemTreeBuilder;
import org.exist.security.PermissionDeniedException;
import org.exist.util.Configuration;
import org.exist.util.ParametersExtractor;
import org.exist.util.io.Resource;
import org.exist.util.io.TemporaryFileManager;
import org.exist.xmldb.XmldbURI;
import org.exist.xquery.*;
import org.exist.xquery.corenlp.util.DefaultBinaryValueManager;
import org.exist.xquery.value.*;

import org.xml.sax.SAXException;

import org.jopendocument.dom.ODPackage;
import org.jopendocument.dom.ODDocument;
import org.jopendocument.dom.spreadsheet.Sheet;
import org.jopendocument.dom.spreadsheet.SpreadSheet;
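/**
 * Implements the train-classifier-spreadsheet-doc XQuery function: it reads annotated
 * training data from a spreadsheet (ODS, XLSX, XLS) or TSV document, trains a Stanford
 * CoreNLP CRF classifier from it and returns the serialized model as binary data.
 */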
public class TrainClassifier extends BasicFunction {
    private final static Logger LOG = LogManager.getLogger(TrainClassifier.class);

    public final static FunctionSignature signatures[] = {
        new FunctionSignature(
            new QName("train-classifier-spreadsheet-doc", StanfordCoreNLPModule.NAMESPACE_URI, StanfordCoreNLPModule.PREFIX),
            "Train a CRF classifier based on annotations in the provided spreadsheet document. Returns a serialized model to use in CRF classification.",
            new SequenceType[] {
                new FunctionParameterSequenceType("classifier", Type.STRING, Cardinality.EXACTLY_ONE,
                    "The fully qualified name of an alternative classifier to load. Must be available on the classpath."),
                new FunctionParameterSequenceType("configuration", Type.ELEMENT, Cardinality.EXACTLY_ONE,
                    "The training configuration, e.g. <parameters><param name='inputFormat' value='ods'/><param name='backgroundSymbol' value='O'/><param name='wordCol' value='0'/><param name='answerCol' value='1'/><param name='tagCol' value='2'/></parameters>. Available input formats: ods (default), xlsx, xls or tsv."),
                new FunctionParameterSequenceType("uploaded-file", Type.BASE64_BINARY, Cardinality.ZERO_OR_ONE,
                    "The uploaded file with your annotated text to train a classifier from. If no file is posted in the request you need to provide localFilePath in the configuration parameter.")
            },
            new FunctionReturnSequenceType(Type.ITEM, Cardinality.EXACTLY_ONE, "The trained serialised CRF classifier"))
    };

    private Path tempInFile = null;
    private Path tempOutFile = null;
    private BinaryValueFromBinaryString uploadedFileBase64String = null;
    private AnalyzeContextInfo cachedContextInfo;
    private Properties parameters = new Properties();

    private enum InputDocType { ODS, XLSX, XLS, TSV };

    private InputDocType inputFormat = InputDocType.ODS;
    private String backgroundSymbol = "O";
    private String localFilePath = null;
    private int wordCol = 0;
    private int answerCol = 1;
    private int tagCol = -1;

    public TrainClassifier(XQueryContext context, FunctionSignature signature) {
        super(context, signature);
    }

    @Override
    public void analyze(AnalyzeContextInfo contextInfo) throws XPathException {
        cachedContextInfo = new AnalyzeContextInfo(contextInfo);
        super.analyze(cachedContextInfo);
    }
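    /**
     * Parses the configuration element, reads the annotated spreadsheet/TSV data into
     * CoreLabel documents, trains the CRF classifier and returns the serialized model
     * as a base64 binary value.
     */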
    @Override
    public Sequence eval(Sequence[] args, Sequence contextSequence) throws XPathException {
        String classifierClassPath = args[0].getStringValue();
        context.pushDocumentContext();
        try {
            Collection<List<CoreLabel>> documents;
            if (!args[1].isEmpty()) {
                parameters = ParametersExtractor.parseParameters(((NodeValue) args[1].itemAt(0)).getNode());
            }
            if (!args[2].isEmpty()) {
                uploadedFileBase64String = new BinaryValueFromBinaryString(new Base64BinaryValueType(), args[2].getStringValue());
            }

            tempInFile = TemporaryFileManager.getInstance().getTemporaryFile();
            tempOutFile = TemporaryFileManager.getInstance().getTemporaryFile();

            for (String property : parameters.stringPropertyNames()) {
                if ("inputFormat".equals(property)) {
                    String value = parameters.getProperty(property);
                    if ("ods".equals(value)) {
                        inputFormat = InputDocType.ODS;
                    } else if ("xlsx".equals(value)) {
                        inputFormat = InputDocType.XLSX;
                    } else if ("xls".equals(value)) {
                        inputFormat = InputDocType.XLS;
                    } else if ("tsv".equals(value)) {
                        inputFormat = InputDocType.TSV;
                    }
                } else if ("backgroundSymbol".equals(property)) {
                    backgroundSymbol = parameters.getProperty(property);
                } else if ("localFilePath".equals(property)) {
                    localFilePath = parameters.getProperty(property);
                } else if ("wordCol".equals(property)) {
                    wordCol = Integer.valueOf(parameters.getProperty(property));
                } else if ("answerCol".equals(property)) {
                    answerCol = Integer.valueOf(parameters.getProperty(property));
                } else if ("tagCol".equals(property)) {
                    tagCol = Integer.valueOf(parameters.getProperty(property));
                }
            }

            documents = readSpreadsheet(inputFormat);

            BinaryValueManager bvm = new DefaultBinaryValueManager(context);
            Base64BinaryDocument bvfis = null;
            if (documents.isEmpty()) {
                LOG.error("No annotated text extracted from the spreadsheet document!");
            } else {
                trainClassifier(documents, inputFormat);
                bvfis = Base64BinaryDocument.getInstance(bvm, Files.newInputStream(tempOutFile));
            }
            return bvfis;
        } catch (IOException ioe) {
            throw new XPathException(this, "Error while reading text document: " + ioe.getMessage(), ioe);
        } finally {
            context.popDocumentContext();
            if (tempInFile != null) {
                TemporaryFileManager.getInstance().returnTemporaryFile(tempInFile);
            }
            if (tempOutFile != null) {
                TemporaryFileManager.getInstance().returnTemporaryFile(tempOutFile);
            }
        }
    }

    private void trainClassifier(Collection<List<CoreLabel>> documents, final InputDocType inputFormat) {
        final Properties props = new Properties();
        // fixme! - check ocrTrain configurable under other name?
        //props.setProperty("ocrTrain", "true");
        //props.setProperty("serializeTo", tempOutFile.toAbsolutePath().toString());
        props.setProperty("useClassFeature", "true");
        props.setProperty("useWord", "true");
        props.setProperty("useNGrams", "true");
        props.setProperty("noMidNGrams", "true");
        props.setProperty("useDisjunctive", "true");
        props.setProperty("maxNGramLeng", "6");
        props.setProperty("usePrev", "true");
        props.setProperty("useNext", "true");
        props.setProperty("useSequences", "true");
        props.setProperty("usePrevSequences", "true");
        props.setProperty("maxLeft", "1");
        props.setProperty("useTypeSeqs", "true");
        props.setProperty("useTypeSeqs2", "true");
        props.setProperty("useTypeySequences", "true");
        props.setProperty("wordShape", "chris2useLC");

        CRFClassifier<CoreLabel> classifier = new CRFClassifier<>(props);
        classifier.train(documents, new ColumnDocumentReaderAndWriter());
        classifier.serializeClassifier(tempOutFile.toAbsolutePath().toString());
    }

    private Collection<List<CoreLabel>> readSpreadsheet(final InputDocType inputFormat) throws XPathException {
        Collection<List<CoreLabel>> res = null;
        if (uploadedFileBase64String == null && localFilePath == null) {
            res = readODSSpreadsheet("/db/temp/swe-clarin/user-annotated.ods");
        } else {
            switch (inputFormat) {
                case ODS:
                    res = readODSSpreadsheet(localFilePath);
                    break;
                case XLSX:
                    res = readXLSXSpreadsheet(localFilePath, inputFormat);
                    break;
                case XLS:
                    res = readXLSXSpreadsheet(localFilePath, inputFormat);
                    break;
                case TSV:
                    res = readTSVSpreadsheet(localFilePath);
                    break;
            }
        }
        return res;
    }
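    // Expected layout of the annotated training data (all readers below): column 0 holds the
    // word/token, column 1 the answer/NER label and an optional column 2 a POS tag; a row with
    // an empty word column ends the current document and starts a new one. Note that the
    // configured wordCol/answerCol/tagCol values are parsed but the readers use the fixed
    // positions 0/1/2; only the ODS reader consults tagCol (> -1) before setting the tag.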
    private Collection<List<CoreLabel>> readODSSpreadsheet(final String localFilePath) throws XPathException {
        Collection<List<CoreLabel>> documents = new ArrayList<>();
        List<CoreLabel> document = new ArrayList<>();
        //try (InputStream is = Files.newInputStream(tempInFile)) {
        try (InputStream is = uploadedFileBase64String != null
                ? uploadedFileBase64String.getInputStream()
                : new Resource(localFilePath).getInputStream()) {
            SpreadSheet spreadSheet = ODPackage.createFromStream(is, "UserAnnotatedDocument").getSpreadSheet();
            Sheet sheet = spreadSheet.getSheet(0);
            for (int i = 0; i < sheet.getRowCount(); i++) {
                CoreLabel row = new CoreLabel();
                String value1 = sheet.getValueAt(0, i).toString();
                String value2 = sheet.getValueAt(1, i).toString();
                row.setWord(value1);
                row.setNER(value2);
                row.set(CoreAnnotations.AnswerAnnotation.class, value2);
                if (sheet.getColumnCount() > 2) {
                    String value3 = sheet.getValueAt(2, i).toString();
                    if (!"".equals(value3) && tagCol > -1) {
                        row.setTag(value3);
                    }
                }
                if (!"".equals(value1)) {
                    document.add(row);
                } else {
                    documents.add(document);
                    document = new ArrayList<>();
                }
            }
        } catch (IOException ioe) {
            throw new XPathException(this, "Error while reading spreadsheet document: " + ioe.getMessage(), ioe);
        }
        return documents;
    }

    private Collection<List<CoreLabel>> readXLSXSpreadsheet(final String localFilePath, final InputDocType inputFormat) throws XPathException {
        Workbook workbook = null;
        Collection<List<CoreLabel>> documents = new ArrayList<>();
        List<CoreLabel> document = new ArrayList<>();
        String extraSuffix = (inputFormat != InputDocType.XLSX) ? "" : "x";
        //try (InputStream is = Files.newInputStream(tempInFile)) {
        try (InputStream is = uploadedFileBase64String != null
                ? uploadedFileBase64String.getInputStream()
                : new Resource(localFilePath + extraSuffix).getInputStream()) {
            if (inputFormat == InputDocType.XLSX) {
                workbook = new XSSFWorkbook(is);
            } else {
                workbook = new HSSFWorkbook(is);
            }
        } catch (FileNotFoundException fe) {
            LOG.error(fe);
        } catch (IOException ioe) {
            LOG.error(ioe);
            throw new XPathException(this, "Error while reading spreadsheet document: " + ioe.getMessage(), ioe);
        }
        if (workbook == null) {
            return documents;
        }
        org.apache.poi.ss.usermodel.Sheet sheet = workbook.getSheetAt(0);
        Iterator<Row> rows = sheet.rowIterator();
        while (rows.hasNext()) {
            CoreLabel tok = new CoreLabel();
            Row row = rows.next();
            Iterator<Cell> cells = row.cellIterator();
            int cellPos = 0;
            while (cells.hasNext()) {
                Cell cell = cells.next();
                //if (cell.getCellType() == Cell.CELL_TYPE_STRING) {
                switch (cellPos) {
                    case 0:
                        tok.setWord(cell.getStringCellValue());
                        break;
                    case 1:
                        tok.setNER(cell.getStringCellValue());
                        tok.set(CoreAnnotations.AnswerAnnotation.class, cell.getStringCellValue());
                        break;
                    case 2:
                        tok.setTag(cell.getStringCellValue());
                        break;
                    default:
                        break;
                }
                //} else if (cell.getCellType() == Cell.CELL_TYPE_NUMERIC) {
                //    LOG.error("Cell has numeric value:" + cell.getNumericCellValue());
                //}
                cellPos++;
            }
            if (!"".equals(tok.word())) {
                document.add(tok);
            } else {
                documents.add(document);
                document = new ArrayList<>();
            }
        }
        return documents;
    }
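    // The TSV variant expects one token per line with tab-separated columns, i.e.
    // word<TAB>answer[<TAB>tag]; a line with an empty word column separates documents.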
    private Collection<List<CoreLabel>> readTSVSpreadsheet(final String localFilePath) throws XPathException {
        String separator = "\t";
        String line;
        Collection<List<CoreLabel>> documents = new ArrayList<>();
        List<CoreLabel> document = new ArrayList<>();
        //try (BufferedReader tsv = Files.newBufferedReader(tempInFile)) {
        try (BufferedReader tsv = uploadedFileBase64String != null
                ? new BufferedReader(new InputStreamReader(uploadedFileBase64String.getInputStream(), StandardCharsets.UTF_8))
                : new Resource(localFilePath).getBufferedReader()) {
            while ((line = tsv.readLine()) != null) {
                CoreLabel tok = new CoreLabel();
                List<String> cells = Arrays.asList(line.split(separator));
                if (cells.size() > 0 && !"".equals(cells.get(0))) {
                    tok.setWord(cells.get(0));
                    tok.setNER(cells.get(1));
                    tok.set(CoreAnnotations.AnswerAnnotation.class, cells.get(1));
                    if (cells.size() > 2 && !"".equals(cells.get(2))) {
                        tok.setTag(cells.get(2));
                    }
                    document.add(tok);
                } else {
                    documents.add(document);
                    document = new ArrayList<>();
                }
            }
        } catch (FileNotFoundException fe) {
            LOG.error(fe);
        } catch (IOException ioe) {
            LOG.error(ioe);
            throw new XPathException(this, "Error while reading spreadsheet document: " + ioe.getMessage(), ioe);
        }
        return documents;
    }
}
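/*
 * Illustrative usage sketch (not part of the module): assuming the module namespace
 * StanfordCoreNLPModule.NAMESPACE_URI is bound to the prefix "corenlp" and a TSV file with
 * annotated training data has been uploaded as $training-data, the function could be called
 * from XQuery roughly like this (classifier name and prefix are placeholders):
 *
 *   corenlp:train-classifier-spreadsheet-doc(
 *       "edu.stanford.nlp.ie.crf.CRFClassifier",
 *       <parameters>
 *           <param name="inputFormat" value="tsv"/>
 *           <param name="backgroundSymbol" value="O"/>
 *           <param name="wordCol" value="0"/>
 *           <param name="answerCol" value="1"/>
 *           <param name="tagCol" value="2"/>
 *       </parameters>,
 *       $training-data)
 *
 * The returned base64 item is the serialized CRF model produced by serializeClassifier() above.
 */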