cz.vse.fis.keg.entityclassifier.core.salience.EntitySaliencer.java Source code

Java tutorial

Introduction

Here is the source code for cz.vse.fis.keg.entityclassifier.core.salience.EntitySaliencer.java

Source

/*
 * #%L
 * Entityclassifier.eu NER CORE v3.9
 * %%
 * Copyright (C) 2015 Knowledge Engineering Group (KEG) and Web Intelligence Research Group (WIRG) - Milan Dojchinovski (milan.dojchinovski@fit.cvut.cz)
 * %%
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public
 * License along with this program.  If not, see
 * <http://www.gnu.org/licenses/gpl-3.0.html>.
 * #L%
 */
package cz.vse.fis.keg.entityclassifier.core.salience;

import cz.vse.fis.keg.entityclassifier.core.THDController;
import cz.vse.fis.keg.entityclassifier.core.conf.Settings;
import cz.vse.fis.keg.entityclassifier.core.vao.Entity;
import cz.vse.fis.keg.entityclassifier.core.vao.Salience;
import cz.vse.fis.keg.entityclassifier.core.vao.Type;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.URL;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
import weka.classifiers.trees.RandomForest;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;

/**
 *
 * @author Milan Dojchinovski <milan.dojchinovski@fit.cvut.cz>
 * http://dojchinovski.mk
 */
public class EntitySaliencer {

    private static final Logger LOG = Logger.getLogger(EntitySaliencer.class.getName());

    private static EntitySaliencer instance = null;

    /** True once a classifier has been trained successfully. */
    private boolean initialized = false;

    /** The trained model; stays null until training succeeds. */
    private RandomForest classifier = null;

    /**
     * Returns the shared instance, creating it on first use.
     *
     * <p>Not synchronized: concurrent first calls may create more than one instance.
     *
     * @return the shared {@code EntitySaliencer}
     */
    public static EntitySaliencer getInstance() {
        if (instance == null) {
            instance = new EntitySaliencer();
        }
        return instance;
    }

    /**
     * (Re)trains the salience classifier from the dataset referenced by
     * {@code Settings.SALIENCE_DATASET}. If training fails the error is logged,
     * any previously trained model is kept, and the saliencer is considered
     * initialized only if some model is available.
     */
    public void initialize() {
        trainModel();
        initialized = classifier != null;
    }

    /**
     * Computes a salience class, confidence and score for the given entity
     * mentions and stores them on every {@link Type} whose entity URI belongs to
     * a scored entity (via {@code Type#setSalience}).
     *
     * <p>Mentions sharing at least one entity URI are grouped into one entity
     * before classification. The model is trained lazily on first use. Errors
     * are logged rather than thrown.
     *
     * @param entities the entity mentions of one document; {@code null} or empty is a no-op
     */
    public void computeSalience(List<Entity> entities) {
        if (entities == null || entities.isEmpty()) {
            return;
        }
        try {
            if (!initialized) {
                initialize();
            }
            if (classifier == null) {
                LOG.warning("Salience model is not available; skipping salience computation.");
                return;
            }

            List<SEntity> groups = groupMentions(entities);
            Instances evalData = buildEvaluationData(groups, entities.size());
            applyPredictions(groups, evalData, entities);
        } catch (Exception ex) {
            LOG.log(Level.SEVERE, "Failed to compute entity salience.", ex);
        }
    }

    /**
     * Groups mentions into entities: a mention joins the first existing group
     * that shares at least one entity URI with it (incrementing that group's
     * occurrence count and merging its URIs), otherwise it starts a new group.
     */
    private static List<SEntity> groupMentions(List<Entity> entities) {
        List<SEntity> groups = new ArrayList<SEntity>();

        for (Entity e : entities) {
            SEntity mention = new SEntity();
            mention.setBeginIndex(e.getStartOffset().intValue());
            mention.setEntityType(e.getEntityType());

            List<Type> types = e.getTypes();
            if (types != null) {
                // Record each distinct URI once, in first-seen order.
                Set<String> seen = new HashSet<String>();
                for (Type t : types) {
                    String uri = t.getEntityURI();
                    if (seen.add(uri)) {
                        mention.getUrls().add(uri);
                    }
                }
            }

            SEntity match = findGroupSharingUri(groups, mention.getUrls());
            if (match == null) {
                // Entity seen for the first time in the document.
                mention.setNumOccurrences(1);
                groups.add(mention);
            } else {
                match.setNumOccurrences(match.getNumOccurrences() + 1);
                for (String uri : mention.getUrls()) {
                    if (!match.getUrls().contains(uri)) {
                        match.getUrls().add(uri);
                    }
                }
            }
        }
        return groups;
    }

    /** Returns the first group containing any of {@code uris}, or null if none does. */
    private static SEntity findGroupSharingUri(List<SEntity> groups, List<String> uris) {
        for (SEntity group : groups) {
            for (String uri : uris) {
                if (group.getUrls().contains(uri)) {
                    return group;
                }
            }
        }
        return null;
    }

    /**
     * Builds one unlabeled Weka instance per entity group. The attribute names,
     * order and nominal values are assumed to match the header of the training
     * dataset — TODO confirm whenever the dataset changes.
     */
    private static Instances buildEvaluationData(List<SEntity> groups, int numMentions) {
        ArrayList<Attribute> attributes = new ArrayList<Attribute>(6);
        attributes.add(new Attribute("beginIndex"));
        attributes.add(new Attribute("numUniqueEntitiesInDoc"));
        attributes.add(new Attribute("numOfOccurrencesOfEntityInDoc"));
        attributes.add(new Attribute("numOfEntityMentionsInDoc"));

        List<String> typeValues = new ArrayList<String>(2);
        typeValues.add("named_entity");
        typeValues.add("common_entity");
        attributes.add(new Attribute("type", typeValues));

        List<String> classValues = new ArrayList<String>(3);
        classValues.add("not_salient");
        classValues.add("less_salient");
        classValues.add("most_salient");
        attributes.add(new Attribute("class", classValues));

        Instances data = new Instances("MyRelation", attributes, groups.size());
        data.setClassIndex(data.numAttributes() - 1);

        for (SEntity group : groups) {
            // A fresh DenseInstance has every value missing, including the class.
            Instance inst = new DenseInstance(data.numAttributes());
            inst.setValue(data.attribute(0), group.getBeginIndex()); // begin index
            inst.setValue(data.attribute(1), groups.size()); // num of unique entities in doc
            inst.setValue(data.attribute(2), group.getNumOccurrences()); // num of entity occurrences in doc
            inst.setValue(data.attribute(3), numMentions); // num of entity mentions in doc
            String nominalType = toNominalType(group.getEntityType());
            if (nominalType != null) {
                inst.setValue(data.attribute(4), nominalType); // type of the entity
            }
            // else: unknown type is left missing instead of aborting the whole batch.
            data.add(inst);
        }
        return data;
    }

    /** Maps an entity type label to the nominal "type" value, or null if it is unknown. */
    private static String toNominalType(String entityType) {
        if ("named entity".equals(entityType)) {
            return "named_entity";
        }
        if ("common entity".equals(entityType)) {
            return "common_entity";
        }
        return null;
    }

    /**
     * Classifies each group, stores the prediction on the group and copies a
     * {@link Salience} (values rounded to three decimals) onto every input type
     * whose URI belongs to that group.
     */
    private void applyPredictions(List<SEntity> groups, Instances evalData, List<Entity> entities)
            throws Exception {
        // One formatter for the whole batch; format+parse rounds to three decimals.
        DecimalFormat threeDecimals = new DecimalFormat("0.000");

        for (int i = 0; i < groups.size(); i++) {
            SEntity group = groups.get(i);

            // A single forest evaluation yields both the label and its probability.
            double[] dist = classifier.distributionForInstance(evalData.instance(i));
            int classIndex = indexOfMax(dist);
            String classLabel = evalData.classAttribute().value(classIndex);
            double probability = dist[classIndex];
            // Weighted score: less_salient (index 1) counts half, most_salient (index 2) fully.
            double salienceScore = dist[1] * 0.5 + dist[2];

            group.setSalienceScore(salienceScore);
            group.setSalienceConfidence(probability);
            group.setSalienceClass(classLabel);

            double roundedProbability =
                    threeDecimals.parse(threeDecimals.format(probability)).doubleValue();
            double roundedScore =
                    threeDecimals.parse(threeDecimals.format(salienceScore)).doubleValue();

            for (Entity e : entities) {
                List<Type> types = e.getTypes();
                if (types == null) {
                    continue;
                }
                for (Type t : types) {
                    if (group.getUrls().contains(t.getEntityURI())) {
                        Salience s = new Salience();
                        s.setClassLabel(classLabel);
                        s.setConfidence(roundedProbability);
                        s.setScore(roundedScore);
                        t.setSalience(s);
                    }
                }
            }
        }
    }

    /**
     * Returns the index of the first strictly largest positive value, or 0 when
     * no value is positive. This is the same rule Weka's default
     * {@code classifyInstance} uses for nominal classes (whose "missing" result
     * casts to 0).
     */
    private static int indexOfMax(double[] values) {
        int best = 0;
        double max = 0;
        for (int i = 0; i < values.length; i++) {
            if (values[i] > max) {
                max = values[i];
                best = i;
            }
        }
        return best;
    }

    /**
     * Trains a {@link RandomForest} on the ARFF resource named by
     * {@code Settings.SALIENCE_DATASET}, resolved relative to the
     * {@code THDController} class. The field {@code classifier} is replaced only
     * after training succeeds; failures are logged.
     */
    private void trainModel() {
        // Stream lookup also works when the resource is packaged inside a JAR.
        InputStream in = THDController.getInstance().getClass()
                .getResourceAsStream(Settings.SALIENCE_DATASET);
        if (in == null) {
            LOG.severe("Salience training dataset not found on the classpath: "
                    + Settings.SALIENCE_DATASET);
            return;
        }

        BufferedReader reader = null;
        try {
            reader = new BufferedReader(new InputStreamReader(in, "UTF-8"));
            Instances data = new Instances(reader);
            data.setClassIndex(data.numAttributes() - 1);

            RandomForest forest = new RandomForest();
            // Train the classifier.
            forest.buildClassifier(data);
            classifier = forest;
            LOG.info("Model was successfully trained.");
        } catch (Exception ex) {
            LOG.log(Level.SEVERE, "Failed to train the salience model.", ex);
        } finally {
            try {
                // Closing the reader also closes the underlying stream.
                if (reader != null) {
                    reader.close();
                } else {
                    in.close();
                }
            } catch (IOException ex) {
                LOG.log(Level.WARNING, "Failed to close the salience dataset.", ex);
            }
        }
    }
}