uk.ac.cam.eng.rule.retrieval.RuleRetriever.java Source code

Java tutorial

Introduction

Here is the source code for uk.ac.cam.eng.rule.retrieval.RuleRetriever.java

Source

/*******************************************************************************
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use these files except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * Copyright 2014 - Juan Pino, Aurelien Waite, William Byrne
 *******************************************************************************/
package uk.ac.cam.eng.rule.retrieval;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.FilenameFilter;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.zip.GZIPOutputStream;

import org.apache.commons.lang.time.StopWatch;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hbase.io.hfile.CacheConfig;
import org.apache.hadoop.hbase.io.hfile.HFile;
import org.apache.hadoop.hbase.util.BloomFilter;
import org.apache.hadoop.hbase.util.BloomFilterFactory;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Partitioner;
import org.apache.hadoop.mapreduce.lib.partition.HashPartitioner;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;

import uk.ac.cam.eng.extraction.datatypes.Rule;
import uk.ac.cam.eng.extraction.hadoop.datatypes.AlignmentAndFeatureMap;
import uk.ac.cam.eng.extraction.hadoop.datatypes.RuleWritable;
import uk.ac.cam.eng.extraction.hadoop.util.Util;
import uk.ac.cam.eng.rulebuilding.features.EnumRuleType;
import uk.ac.cam.eng.rulebuilding.features.FeatureCreator;

import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParameterException;
import com.beust.jcommander.Parameters;

/**
 * 
 * @author Juan Pino
 * @author Aurelien Waite
 * @date 28 May 2014
 * 
 * Rory's improvements to the retriever over Juan's original implementation
 * include: 
 * 1) A better search algorithm that does not query unigrams out of order of the main search 
 * 2) A more compact HFile format, which results in much smaller HFiles and faster access
 * 3) Multithreaded access to HFile partitions.
 */
public class RuleRetriever extends Configured implements Tool {

    private static final String RETRIEVEL_THREADS = "retrieval.threads";

    private BloomFilter[] bfs;

    private HFileRuleReader[] readers;

    private Partitioner<Text, NullWritable> partitioner = new HashPartitioner<>();

    RuleFilter filter;

    private String asciiConstraintsFileName;

    Set<RuleWritable> asciiConstraints;

    Set<RuleWritable> foundAsciiConstraints = new HashSet<>();

    Set<Text> testVocab;

    Set<Text> foundTestVocab = new HashSet<>();

    private int MAX_SOURCE_PHRASE;

    private void setup(String testFile, Configuration conf) throws FileNotFoundException, IOException {

        asciiConstraintsFileName = conf.get("pass_through_rules");
        String filterConfig = conf.get("filter_config");
        if (filterConfig == null) {
            System.err.println("Missing property 'filter_config' in the config");
            System.exit(1);
        }
        MAX_SOURCE_PHRASE = conf.getInt("max_source_phrase", 5);
        filter = new RuleFilter(conf);
        asciiConstraints = getAsciiConstraints();
        Set<Text> fullTestVocab = getTestVocab(testFile);
        Set<Text> asciiVocab = getAsciiVocab();
        testVocab = new HashSet<>();
        for (Text word : fullTestVocab) {
            if (!asciiVocab.contains(word)) {
                testVocab.add(word);
            }
        }

    }

    private void loadDir(String dirString, Configuration conf) throws IOException {
        File dir = new File(dirString);
        CacheConfig cacheConf = new CacheConfig(conf);
        if (!dir.isDirectory()) {
            throw new IOException(dirString + " is not a directory!");
        }
        File[] names = dir.listFiles(new FilenameFilter() {
            @Override
            public boolean accept(File dir, String name) {
                return name.endsWith("hfile");
            }
        });
        if (names.length == 0) {
            throw new IOException("No hfiles in " + dirString);
        }
        bfs = new BloomFilter[names.length];
        readers = new HFileRuleReader[names.length];
        for (File file : names) {
            String name = file.getName();
            int i = Integer.parseInt(name.substring(7, 12));
            HFile.Reader hfReader = HFile.createReader(FileSystem.getLocal(conf), new Path(file.getPath()),
                    cacheConf);
            bfs[i] = BloomFilterFactory.createFromMeta(hfReader.getBloomFilterMetadata(), hfReader);
            readers[i] = new HFileRuleReader(hfReader);
        }
    }

    private Set<RuleWritable> getAsciiConstraints() throws IOException {
        Set<RuleWritable> res = new HashSet<>();
        try (BufferedReader br = new BufferedReader(new FileReader(asciiConstraintsFileName))) {
            String line;
            Pattern regex = Pattern.compile(".*: (.*) # (.*)");
            Matcher matcher;
            while ((line = br.readLine()) != null) {
                matcher = regex.matcher(line);
                if (matcher.matches()) {
                    String[] sourceString = matcher.group(1).split(" ");
                    String[] targetString = matcher.group(2).split(" ");
                    if (sourceString.length != targetString.length) {
                        System.err.println("Malformed ascii constraint file: " + asciiConstraintsFileName);
                        System.exit(1);
                    }
                    List<Integer> source = new ArrayList<Integer>();
                    List<Integer> target = new ArrayList<Integer>();
                    int i = 0;
                    while (i < sourceString.length) {
                        if (i % MAX_SOURCE_PHRASE == 0 && i > 0) {
                            Rule rule = new Rule(-1, source, target);
                            res.add(new RuleWritable(rule));
                            source.clear();
                            target.clear();
                        }
                        source.add(Integer.parseInt(sourceString[i]));
                        target.add(Integer.parseInt(targetString[i]));
                        i++;
                    }
                    Rule rule = new Rule(-1, source, target);
                    res.add(new RuleWritable(rule));
                } else {
                    System.err.println("Malformed ascii constraint file: " + asciiConstraintsFileName);
                    System.exit(1);
                }
            }
        }
        return res;
    }

    private Set<Text> getAsciiVocab() throws IOException {
        // TODO simplify all template writing
        // TODO getAsciiVocab is redundant with getAsciiConstraints
        Set<Text> res = new HashSet<>();
        try (BufferedReader br = new BufferedReader(new FileReader(asciiConstraintsFileName))) {
            String line;
            Pattern regex = Pattern.compile(".*: (.*) # (.*)");
            Matcher matcher;
            while ((line = br.readLine()) != null) {
                matcher = regex.matcher(line);
                if (matcher.matches()) {
                    String[] sourceString = matcher.group(1).split(" ");
                    // only one word
                    if (sourceString.length == 1) {
                        res.add(new Text(sourceString[0]));
                    }
                } else {
                    System.err.println("Malformed ascii constraint file: " + asciiConstraintsFileName);
                    System.exit(1);
                }
            }
        }
        return res;
    }

    private Set<Text> getTestVocab(String testFile) throws FileNotFoundException, IOException {
        Set<Text> res = new HashSet<>();
        try (BufferedReader br = new BufferedReader(new FileReader(testFile))) {
            String line;
            while ((line = br.readLine()) != null) {
                String[] parts = line.split("\\s+");
                for (String part : parts) {
                    res.add(new Text(part));
                }
            }
        }
        return res;
    }

    public Collection<RuleWritable> getGlueRules() {
        List<RuleWritable> res = new ArrayList<>();
        List<Integer> sideGlueRule1 = new ArrayList<Integer>();
        sideGlueRule1.add(-4);
        sideGlueRule1.add(-1);
        Rule glueRule1 = new Rule(-4, sideGlueRule1, sideGlueRule1);
        res.add(new RuleWritable(glueRule1));
        List<Integer> sideGlueRule2 = new ArrayList<Integer>();
        sideGlueRule2.add(-1);
        Rule glueRule2 = new Rule(-1, sideGlueRule2, sideGlueRule2);
        res.add(new RuleWritable(glueRule2));
        List<Integer> sideGlueRule3 = new ArrayList<>();
        sideGlueRule3.add(-1);
        Rule glueRule3 = new Rule(-4, sideGlueRule3, sideGlueRule3);
        res.add(new RuleWritable(glueRule3));
        List<Integer> startSentenceSide = new ArrayList<Integer>();
        startSentenceSide.add(1);
        Rule startSentence = new Rule(-1, startSentenceSide, startSentenceSide);
        res.add(new RuleWritable(startSentence));
        List<Integer> endSentenceSide = new ArrayList<Integer>();
        endSentenceSide.add(2);
        Rule endSentence = new Rule(-1, endSentenceSide, endSentenceSide);
        res.add(new RuleWritable(endSentence));
        return res;
    }

    private List<Set<Text>> generateQueries(String testFileName, Configuration conf) throws IOException {
        PatternInstanceCreator patternInstanceCreator = new PatternInstanceCreator(conf);
        patternInstanceCreator.createSourcePatterns();
        List<Set<Text>> queries = new ArrayList<>(readers.length);
        for (int i = 0; i < readers.length; ++i) {
            queries.add(new HashSet<Text>());
        }
        try (BufferedReader reader = new BufferedReader(new FileReader(testFileName))) {
            int count = 0;
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                StopWatch stopWatch = new StopWatch();
                stopWatch.start();
                Set<Rule> rules = patternInstanceCreator.createSourcePatternInstances(line);
                Collection<Text> sources = new ArrayList<>(rules.size());
                for (Rule rule : rules) {
                    Text source = (new RuleWritable(rule)).getSource();
                    sources.add(source);
                }
                for (Text source : sources) {
                    if (filter.filterSource(source)) {
                        continue;
                    }
                    int partition = partitioner.getPartition(source, null, queries.size());
                    queries.get(partition).add(source);
                }
                System.out.println("Creating patterns for line " + ++count + " took "
                        + (double) stopWatch.getTime() / 1000d + " seconds");
            }
        }
        return queries;
    }

    /**
     * Defines command line args.
     */
    @Parameters(separators = "=")
    public static class RuleRetrieverParameters {
        @Parameter(names = {
                "--max_source_phrase" }, description = "Maximum source phrase length in a phrase-based rule")
        public String max_source_phrase = "9";

        @Parameter(names = {
                "--max_source_elements" }, description = "Maximum number of source elements (terminals and nonterminals) in a hiero rule")
        public String max_source_elements = "5";

        @Parameter(names = {
                "--max_terminal_length" }, description = "Maximum number of consecutive terminals in a hiero rule")
        public String max_terminal_length = "5";

        @Parameter(names = {
                "--max_nonterminal_span" }, description = "Maximum number of source terminals covered by a right-hand-side source nonterminal in a hiero rule")
        public String max_nonterminal_span = "10";

        @Parameter(names = {
                "--hr_max_height" }, description = "Maximum number of source terminals covered by the left-hand-side nonterminal in a hiero rule")
        public String hr_max_height = "10";

        @Parameter(names = {
                "--mapreduce_features" }, description = "Comma-separated list of mapreduce features", required = true)
        public String mapreduce_features;

        @Parameter(names = { "--provenance" }, description = "Comma-separated list of provenances")
        public String provenance;

        @Parameter(names = {
                "--features" }, description = "Comma-separated list of features, including mapreduce features", required = true)
        public String features;

        @Parameter(names = { "--pass_through_rules" }, description = "File containing pass-through rules")
        public String pass_through_rules;

        @Parameter(names = {
                "--filter_config" }, description = "File containing additional filtering configuration, e.g. min source-to-target probability")
        public String filter_config;

        @Parameter(names = {
                "--source_patterns" }, description = "File containing a list of allowed source patterns")
        public String source_patterns;

        @Parameter(names = { "--ttable_s2t_server_port" }, description = "TTable source-to-target server port")
        public String ttable_s2t_server_port = "4949";

        @Parameter(names = { "--ttable_s2t_host" }, description = "TTable source-to-target host name")
        public String ttable_s2t_host = "localhost";

        @Parameter(names = { "--ttable_t2s_server_port" }, description = "TTable target-to-source server port")
        public String ttable_t2s_server_port = "9494";

        @Parameter(names = { "--ttable_t2s_host" }, description = "TTable target-to-source host name")
        public String ttable_t2s_host = "localhost";

        @Parameter(names = {
                "--retrieval_threads" }, description = "Number of threads for retrieval, corresponds to the number of hfiles")
        public String retrieval_threads;

        @Parameter(names = { "--hfile" }, description = "Directory containing the hfiles")
        public String hfile;

        @Parameter(names = { "--test_file" }, description = "File containing the sentences to be translated")
        public String test_file;

        @Parameter(names = { "--rules" }, description = "Output file containing filtered rules")
        public String rules;
    }

    /**
     * @param args
     * @throws IOException
     * @throws FileNotFoundException
     * @throws InterruptedException
     * @throws IllegalAccessException
     * @throws IllegalArgumentException
     */
    public int run(String[] args) throws FileNotFoundException, IOException, InterruptedException,
            IllegalArgumentException, IllegalAccessException {
        RuleRetrieverParameters params = new RuleRetrieverParameters();
        JCommander cmd = new JCommander(params);

        try {
            cmd.parse(args);
            Configuration conf = getConf();
            Util.ApplyConf(cmd, FeatureCreator.MAPRED_SUFFIX, conf);
            RuleRetriever retriever = new RuleRetriever();
            retriever.loadDir(params.hfile, conf);
            retriever.setup(params.test_file, conf);
            StopWatch stopWatch = new StopWatch();
            stopWatch.start();
            System.err.println("Generating query");
            List<Set<Text>> queries = retriever.generateQueries(params.test_file, conf);
            System.err.printf("Query took %d seconds to generate\n", stopWatch.getTime() / 1000);
            System.err.println("Executing queries");
            try (BufferedWriter out = new BufferedWriter(
                    new OutputStreamWriter(new GZIPOutputStream(new FileOutputStream(params.rules))))) {
                FeatureCreator features = new FeatureCreator(conf);
                ExecutorService threadPool = Executors.newFixedThreadPool(conf.getInt(RETRIEVEL_THREADS, 1));

                for (int i = 0; i < queries.size(); ++i) {
                    HFileRuleQuery query = new HFileRuleQuery(retriever.readers[i], retriever.bfs[i], out,
                            queries.get(i), features, retriever, conf);
                    threadPool.execute(query);
                }
                threadPool.shutdown();
                threadPool.awaitTermination(1, TimeUnit.DAYS);
                // Add ascii constraints not already found in query
                for (RuleWritable asciiConstraint : retriever.asciiConstraints) {
                    if (!retriever.foundAsciiConstraints.contains(asciiConstraint)) {
                        features.writeRule(asciiConstraint, AlignmentAndFeatureMap.EMPTY,
                                EnumRuleType.ASCII_OOV_DELETE, out);
                    }
                }
                // Add Deletetion and OOV rules
                RuleWritable deletionRuleWritable = new RuleWritable();
                deletionRuleWritable.setLeftHandSide(new Text(EnumRuleType.ASCII_OOV_DELETE.getLhs()));
                deletionRuleWritable.setTarget(new Text("0"));
                RuleWritable oovRuleWritable = new RuleWritable();
                oovRuleWritable.setLeftHandSide(new Text(EnumRuleType.ASCII_OOV_DELETE.getLhs()));
                oovRuleWritable.setTarget(new Text(""));
                for (Text source : retriever.testVocab) {
                    if (retriever.foundTestVocab.contains(source)) {
                        deletionRuleWritable.setSource(source);
                        features.writeRule(deletionRuleWritable, AlignmentAndFeatureMap.EMPTY,
                                EnumRuleType.ASCII_OOV_DELETE, out);
                    } else {
                        oovRuleWritable.setSource(source);
                        features.writeRule(oovRuleWritable, AlignmentAndFeatureMap.EMPTY,
                                EnumRuleType.ASCII_OOV_DELETE, out);
                    }
                }
                // Glue rules
                for (RuleWritable glueRule : retriever.getGlueRules()) {
                    features.writeRule(glueRule, AlignmentAndFeatureMap.EMPTY, EnumRuleType.GLUE, out);
                }
            }
            System.out.println(retriever.foundAsciiConstraints);
            System.out.println(retriever.foundTestVocab);
        } catch (ParameterException e) {
            System.err.println(e.getMessage());
            cmd.usage();
        }

        return 1;
    }

    public static void main(String[] args) throws Exception {
        int res = ToolRunner.run(new RuleRetriever(), args);
        System.exit(res);
    }
}