kea.KEAKeyphraseExtractor.java Source code

Java tutorial

Introduction

Here is the source code for kea.KEAKeyphraseExtractor.java

Source

/*
 *    KEAKeyphraseExtractor.java
 *    Copyright (C) 2001 Eibe Frank
 *
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */
package kea;

import gate.util.BomStrippingInputStreamReader;

import java.io.*;
import java.util.*;

import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.OptionHandler;
import weka.core.Utils;
import weka.core.FastVector;
import weka.core.Option;

/**
 * Extracts keyphrases from the documents in a given directory.
 * Assumes that the file names for the documents end with ".txt".
 * Puts extracted keyphrases into corresponding files ending with
 * ".key" (if those are not already present). Optionally an encoding
 * for the documents/keyphrases can be defined (e.g. for Chinese
 * text). Documents for which a ".key" file already exists are used
 * for evaluation: the average number of top-ranked phrases whose class
 * value is "True" is reported on standard error.
 *
 * Valid options are:<p>
 *
 * -l "directory name"<br>
 * Specifies name of directory.<p>
 *
 * -m "model name"<br>
 * Specifies name of model.<p>
 *
 * -e "encoding"<br>
 * Specifies encoding.<p>
 *
 * -n "number of phrases"<br>
 * Specifies number of phrases to be output (default: 5).<p>
 *
 * -d<br>
 * Turns debugging mode on.<p>
 *
 * -a<br>
 * Also write stemmed phrase and score into ".key" file.<p>
 *
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @version 1.0
 */
@SuppressWarnings({ "rawtypes", "unchecked", "cast" })
public class KEAKeyphraseExtractor implements OptionHandler {

    /** Encoding value meaning "use the platform default charset". */
    private static final String DEFAULT_ENCODING = "default";

    /** Name of directory */
    String m_dirName = null;

    /** Name of model */
    String m_modelName = null;

    /** Encoding ("default" selects the platform charset) */
    String m_encoding = DEFAULT_ENCODING;

    /** Debugging mode? */
    boolean m_debug = false;

    /** The KEA filter object (set by loadModel) */
    KEAFilter m_KEAFilter = null;

    /** The number of phrases to extract. */
    int m_numPhrases = 5;

    /** Also write stemmed phrase and score into .key file. */
    boolean m_AdditionalInfo = false;

    /**
     * Get the value of AdditionalInfo.
     *
     * @return Value of AdditionalInfo.
     */
    public boolean getAdditionalInfo() {

        return m_AdditionalInfo;
    }

    /**
     * Set the value of AdditionalInfo.
     *
     * @param newAdditionalInfo Value to assign to AdditionalInfo.
     */
    public void setAdditionalInfo(boolean newAdditionalInfo) {

        m_AdditionalInfo = newAdditionalInfo;
    }

    /**
     * Get the value of numPhrases.
     *
     * @return Value of numPhrases.
     */
    public int getNumPhrases() {

        return m_numPhrases;
    }

    /**
     * Set the value of numPhrases.
     *
     * @param newnumPhrases Value to assign to numPhrases.
     */
    public void setNumPhrases(int newnumPhrases) {

        m_numPhrases = newnumPhrases;
    }

    /**
     * Get the value of debug.
     *
     * @return Value of debug.
     */
    public boolean getDebug() {

        return m_debug;
    }

    /**
     * Set the value of debug.
     *
     * @param newdebug Value to assign to debug.
     */
    public void setDebug(boolean newdebug) {

        m_debug = newdebug;
    }

    /**
     * Get the value of encoding.
     *
     * @return Value of encoding.
     */
    public String getEncoding() {

        return m_encoding;
    }

    /**
     * Set the value of encoding.
     *
     * @param newencoding Value to assign to encoding.
     */
    public void setEncoding(String newencoding) {

        m_encoding = newencoding;
    }

    /**
     * Get the value of modelName.
     *
     * @return Value of modelName.
     */
    public String getModelName() {

        return m_modelName;
    }

    /**
     * Set the value of modelName.
     *
     * @param newmodelName Value to assign to modelName.
     */
    public void setModelName(String newmodelName) {

        m_modelName = newmodelName;
    }

    /**
     * Get the value of dirName.
     *
     * @return Value of dirName.
     */
    public String getDirName() {

        return m_dirName;
    }

    /**
     * Set the value of dirName.
     *
     * @param newdirName Value to assign to dirName.
     */
    public void setDirName(String newdirName) {

        m_dirName = newdirName;
    }

    /**
     * Parses a given list of options controlling the behaviour of this object.
     * Valid options are:<p>
     *
     * -l "directory name"<br>
     * Specifies name of directory.<p>
     *
     * -m "model name"<br>
     * Specifies name of model.<p>
     *
     * -e "encoding"<br>
     * Specifies encoding.<p>
     *
     * -n "number of phrases"<br>
     * Specifies number of phrases to be output (default: 5).<p>
     *
     * -d<br>
     * Turns debugging mode on.<p>
     *
     * -a<br>
     * Also write stemmed phrase and score into ".key" file.<p>
     *
     * @param options the list of options as an array of strings
     * @exception Exception if an option is not supported, or if the required
     *            -l or -m option is missing
     */
    public void setOptions(String[] options) throws Exception {

        String dirName = Utils.getOption('l', options);
        if (dirName.length() > 0) {
            setDirName(dirName);
        } else {
            setDirName(null);
            throw new Exception("Name of directory required argument.");
        }
        String modelName = Utils.getOption('m', options);
        if (modelName.length() > 0) {
            setModelName(modelName);
        } else {
            setModelName(null);
            throw new Exception("Name of model required argument.");
        }
        String encoding = Utils.getOption('e', options);
        if (encoding.length() > 0) {
            setEncoding(encoding);
        } else {
            setEncoding(DEFAULT_ENCODING);
        }
        String numPhrases = Utils.getOption('n', options);
        if (numPhrases.length() > 0) {
            setNumPhrases(Integer.parseInt(numPhrases));
        } else {
            setNumPhrases(5);
        }
        setDebug(Utils.getFlag('d', options));
        setAdditionalInfo(Utils.getFlag('a', options));
        Utils.checkForRemainingOptions(options);
    }

    /**
     * Gets the current option settings.
     *
     * @return an array of strings suitable for passing to setOptions; unused
     *         trailing slots are filled with empty strings
     */
    public String[] getOptions() {

        // 4 options with arguments (8 slots) + 2 flags
        String[] options = new String[10];
        int current = 0;

        options[current++] = "-l";
        options[current++] = "" + (getDirName());
        options[current++] = "-m";
        options[current++] = "" + (getModelName());
        options[current++] = "-e";
        options[current++] = "" + (getEncoding());
        options[current++] = "-n";
        options[current++] = "" + (getNumPhrases());
        if (getDebug()) {
            options[current++] = "-d";
        }
        if (getAdditionalInfo()) {
            options[current++] = "-a";
        }

        while (current < options.length) {
            options[current++] = "";
        }
        return options;
    }

    /**
     * Returns an enumeration describing the available options.
     *
     * @return an enumeration of all the available options
     */
    public Enumeration listOptions() {

        Vector newVector = new Vector(6);

        newVector.addElement(new Option("\tSpecifies name of directory.", "l", 1, "-l <directory name>"));
        newVector.addElement(new Option("\tSpecifies name of model.", "m", 1, "-m <model name>"));
        newVector.addElement(new Option("\tSpecifies encoding.", "e", 1, "-e <encoding>"));
        // -n takes one argument, so the synopsis shows its placeholder like the others.
        newVector.addElement(new Option("\tSpecifies number of phrases to be output (default: 5).", "n", 1,
                "-n <number of phrases>"));
        newVector.addElement(new Option("\tTurns debugging mode on.", "d", 0, "-d"));
        newVector.addElement(new Option("\tAlso write stemmed phrase and score into \".key\" file.", "a", 0, "-a"));

        return newVector.elements();
    }

    /**
     * Collects the stems of the ".txt" file names in the directory.
     *
     * @return a table whose keys are the file names with the ".txt" suffix
     *         removed; every value is a placeholder Double of 0
     * @exception Exception if the directory name is unset or the directory
     *            cannot be listed
     */
    public Hashtable collectStems() throws Exception {

        String[] files;
        try {
            files = (m_dirName == null) ? null : new File(m_dirName).list();
        } catch (SecurityException e) {
            throw new Exception("Problem opening directory " + m_dirName, e);
        }
        // File.list() returns null (instead of throwing) when the path is not
        // a directory or cannot be read.
        if (files == null) {
            throw new Exception("Problem opening directory " + m_dirName);
        }

        Hashtable stems = new Hashtable();
        for (int i = 0; i < files.length; i++) {
            if (files[i].endsWith(".txt")) {
                String stem = files[i].substring(0, files[i].length() - ".txt".length());
                if (!stems.containsKey(stem)) {
                    stems.put(stem, Double.valueOf(0.0));
                }
            }
        }
        return stems;
    }

    /**
     * Runs the loaded model over every document stem. For each stem, writes
     * the top-ranked keyphrases to "stem.key" when that file does not exist
     * yet. Finally, reports on standard error the average number of
     * top-ranked phrases whose class value is "True", over the documents
     * that had a known class value.
     *
     * @param stems table whose keys are document stems (file names without
     *        the ".txt" suffix), e.g. as returned by collectStems()
     * @exception Exception if there are no stems, no model has been loaded,
     *            or a ".key" file cannot be created
     */
    public void extractKeyphrases(Hashtable stems) throws Exception {

        // Check whether there is actually any data
        if (stems.size() == 0) {
            throw new Exception("Couldn't find any data!");
        }
        if (m_KEAFilter == null) {
            throw new Exception("No model loaded; call loadModel() first.");
        }

        Vector stats = new Vector();

        FastVector atts = new FastVector(2);
        atts.addElement(new Attribute("doc", (FastVector) null));
        atts.addElement(new Attribute("keyphrases", (FastVector) null));
        Instances data = new Instances("keyphrase_training_data", atts, 0);

        // Extract keyphrases
        Enumeration elem = stems.keys();
        while (elem.hasMoreElements()) {
            String str = (String) elem.nextElement();
            File txt = new File(m_dirName + "/" + str + ".txt");
            File key = new File(m_dirName + "/" + str + ".key");

            double[] newInst = new double[2];
            try {
                newInst[0] = (double) data.attribute(0).addStringValue(readFile(txt));
            } catch (Exception e) {
                if (m_debug) {
                    System.err.println("Can't read document " + str + ".txt");
                }
                newInst[0] = Instance.missingValue();
            }
            // An already existing ".key" file is passed to the filter as the
            // document's keyphrases; otherwise that value is missing.
            try {
                newInst[1] = (double) data.attribute(1).addStringValue(readFile(key));
            } catch (Exception e) {
                if (m_debug) {
                    System.err.println("No keyphrases for stem " + str + ".");
                }
                newInst[1] = Instance.missingValue();
            }
            data.add(new Instance(1.0, newInst));
            m_KEAFilter.input(data.instance(0));
            // Start the next document from an empty header so it is again instance 0.
            data = data.stringFreeStructure();
            if (m_debug) {
                System.err.println("-- Document: " + str);
            }
            Instance[] topRankedInstances = collectTopRanked();
            if (m_debug) {
                System.err.println("-- Keyphrases and feature values:");
            }

            // Only write a ".key" file if none exists; existing ones are left untouched.
            PrintWriter printer = key.exists() ? null : openWriter(key);
            double numExtracted = 0, numCorrect = 0;
            try {
                for (int i = 0; i < m_numPhrases; i++) {
                    Instance ranked = topRankedInstances[i];
                    if (ranked == null) {
                        continue;
                    }
                    // The last attribute is treated as the class: a phrase counts
                    // as correct when its (known) value is "True".
                    int classIndex = ranked.numAttributes() - 1;
                    if (!ranked.isMissing(classIndex)) {
                        numExtracted += 1.0;
                        // Checked only for known values: a missing value cast to
                        // an int index could otherwise spuriously match "True".
                        if ((int) ranked.value(classIndex) == ranked.attribute(classIndex).indexOfValue("True")) {
                            numCorrect += 1.0;
                        }
                    }
                    if (printer != null) {
                        printer.print(ranked.stringValue(m_KEAFilter.getUnstemmedPhraseIndex()));
                        if (m_AdditionalInfo) {
                            printer.print("\t");
                            printer.print(ranked.stringValue(m_KEAFilter.getStemmedPhraseIndex()));
                            printer.print("\t");
                            printer.print(Utils.doubleToString(ranked.value(m_KEAFilter.getProbabilityIndex()), 4));
                        }
                        printer.println();
                    }
                    if (m_debug) {
                        System.err.println(ranked);
                    }
                }
                // PrintWriter never throws on write failures; report them instead
                // of silently leaving a truncated ".key" file.
                if (printer != null && printer.checkError()) {
                    System.err.println("Error writing keyphrases to " + key);
                }
            } finally {
                if (printer != null) {
                    printer.close();
                }
            }
            if (numExtracted > 0) {
                if (m_debug) {
                    System.err.println("-- " + numCorrect + " correct");
                }
                stats.addElement(Double.valueOf(numCorrect));
            }
        }

        double[] st = new double[stats.size()];
        for (int i = 0; i < st.length; i++) {
            st[i] = ((Double) stats.elementAt(i)).doubleValue();
        }
        double avg = Utils.mean(st);
        double stdDev = Math.sqrt(Utils.variance(st));
        System.err.println("Avg. number of correct keyphrases: " + Utils.doubleToString(avg, 2) + " +/- "
                + Utils.doubleToString(stdDev, 2));
        System.err.println("Based on " + stats.size() + " documents");
        m_KEAFilter.batchFinished();
    }

    /**
     * Drains the filter's output for the current document, keeping the
     * instances whose rank is within the first m_numPhrases.
     *
     * @return array indexed by (rank - 1); slots without an instance are null
     */
    private Instance[] collectTopRanked() throws Exception {

        Instance[] top = new Instance[m_numPhrases];
        Instance inst;
        while ((inst = m_KEAFilter.output()) != null) {
            // Ranks are 1-based, hence the -1.
            int index = (int) inst.value(m_KEAFilter.getRankIndex()) - 1;
            // Ignore out-of-range ranks rather than failing on the array bounds.
            if (index >= 0 && index < m_numPhrases) {
                top[index] = inst;
            }
        }
        return top;
    }

    /**
     * Reads the whole file into a string using the configured encoding,
     * always closing the reader.
     *
     * @param file the file to read
     * @return the file's contents
     * @exception IOException if the file cannot be opened or read
     */
    private String readFile(File file) throws IOException {

        Reader reader = openReader(file);
        try {
            StringBuilder contents = new StringBuilder();
            char[] buffer = new char[4096];
            int n;
            while ((n = reader.read(buffer)) != -1) {
                contents.append(buffer, 0, n);
            }
            return contents.toString();
        } finally {
            reader.close();
        }
    }

    /**
     * Opens a BOM-stripping reader on the file using the configured
     * encoding ("default" selects the platform charset).
     *
     * @param file the file to open
     * @return an open reader; the caller must close it
     * @exception IOException if the file cannot be opened or the encoding is unsupported
     */
    private Reader openReader(File file) throws IOException {

        InputStream in = new FileInputStream(file);
        boolean opened = false;
        try {
            Reader reader;
            if (!m_encoding.equals(DEFAULT_ENCODING)) {
                reader = new BomStrippingInputStreamReader(in, m_encoding);
            } else {
                reader = new BomStrippingInputStreamReader(in);
            }
            opened = true;
            return reader;
        } finally {
            // Don't leak the file handle if the reader could not be created.
            if (!opened) {
                in.close();
            }
        }
    }

    /**
     * Creates (or truncates) the file and opens a writer on it using the
     * configured encoding ("default" selects the platform charset).
     *
     * @param file the file to write
     * @return an open writer; the caller must close it
     * @exception IOException if the file cannot be created or the encoding is unsupported
     */
    private PrintWriter openWriter(File file) throws IOException {

        OutputStream out = new FileOutputStream(file);
        boolean opened = false;
        try {
            Writer writer;
            if (!m_encoding.equals(DEFAULT_ENCODING)) {
                writer = new OutputStreamWriter(out, m_encoding);
            } else {
                writer = new OutputStreamWriter(out);
            }
            opened = true;
            return new PrintWriter(new BufferedWriter(writer));
        } finally {
            // Don't leak the file handle if the writer could not be created.
            if (!opened) {
                out.close();
            }
        }
    }

    /**
     * Loads the extraction model (a serialized KEAFilter) from the file
     * named by the model name.
     *
     * <p>Java deserialization can execute code contained in the stream, so
     * only load model files from trusted sources.
     *
     * @exception Exception if the file cannot be read or does not contain a KEAFilter
     */
    public void loadModel() throws Exception {

        InputStream fileIn = new FileInputStream(m_modelName);
        try {
            ObjectInputStream in = new ObjectInputStream(new BufferedInputStream(fileIn));
            try {
                m_KEAFilter = (KEAFilter) in.readObject();
            } finally {
                in.close();
            }
        } finally {
            // Also covers a failure while ObjectInputStream reads the stream header.
            fileIn.close();
        }
    }

    /**
     * The main method. Parses the options, loads the model and extracts
     * keyphrases; on any error prints it together with the option list.
     *
     * @param ops the command-line options
     */
    public static void main(String[] ops) {

        KEAKeyphraseExtractor kmb = new KEAKeyphraseExtractor();
        try {
            kmb.setOptions(ops);
            System.err.print("Extracting keyphrases with options: ");
            String[] optionSettings = kmb.getOptions();
            for (int i = 0; i < optionSettings.length; i++) {
                System.err.print(optionSettings[i] + " ");
            }
            System.err.println();
            kmb.loadModel();
            kmb.extractKeyphrases(kmb.collectStems());
        } catch (Exception e) {
            e.printStackTrace();
            System.err.println(e.getMessage());
            System.err.println("\nOptions:\n");
            Enumeration enumeration = kmb.listOptions();
            while (enumeration.hasMoreElements()) {
                Option option = (Option) enumeration.nextElement();
                System.err.println(option.synopsis());
                System.err.println(option.description());
            }
        }
    }
}