uk.ac.cam.eng.extraction.hadoop.features.lexical.TTableServer.java Source code

Java tutorial

Introduction

Here is the source code for uk.ac.cam.eng.extraction.hadoop.features.lexical.TTableServer.java

Source

/*******************************************************************************
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use these files except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * Copyright 2014 - Juan Pino, Aurelien Waite, William Byrne
 *******************************************************************************/
package uk.ac.cam.eng.extraction.hadoop.features.lexical;

import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.ByteArrayOutputStream;
import java.io.Closeable;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketException;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.zip.GZIPInputStream;

import org.apache.commons.lang.time.StopWatch;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.util.StringUtils;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;

import uk.ac.cam.eng.extraction.hadoop.datatypes.ProvenanceCountMap;
import uk.ac.cam.eng.extraction.hadoop.util.Util;

import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParameterException;
import com.beust.jcommander.Parameters;

/**
 * 
 * @author Aurelien Waite
 * @date 28 May 2014
 */
public class TTableServer extends Configured implements Closeable, Tool {

    final static int BUFFER_SIZE = 65536;

    private ExecutorService threadPool = Executors.newFixedThreadPool(6);

    private class LoadTask implements Runnable {

        private final String fileName;
        private final byte prov;

        private LoadTask(String fileName, byte prov) {
            this.fileName = fileName;
            this.prov = prov;
        }

        @Override
        public void run() {
            try {
                loadModel(fileName, prov);
            } catch (IOException e) {
                e.printStackTrace();
                System.exit(1);
            }

        }

    }

    private class QueryRunnable implements Runnable {

        private Socket querySocket;

        private ByteArrayOutputStream byteBuffer = new ByteArrayOutputStream(BUFFER_SIZE);

        private DataOutputStream probWriter = new DataOutputStream(byteBuffer);

        private long queryTime = 0;

        private long totalKeys = 0;

        private int noOfQueries = 0;

        private QueryRunnable(Socket querySocket) {
            this.querySocket = querySocket;
        }

        @Override
        public void run() {
            try {
                runWithExceptions();
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }

        private void runWithExceptions() throws IOException {
            try (DataInputStream queryReader = new DataInputStream(
                    new BufferedInputStream(querySocket.getInputStream()))) {
                try (OutputStream out = querySocket.getOutputStream()) {
                    StopWatch stopWatch = new StopWatch();
                    // A bit nasty, but will block on the readInt.
                    // It's not really polling. Honest!
                    try {
                        int querySize = queryReader.readInt();
                        totalKeys += querySize;
                        stopWatch.start();
                        for (int i = 0; i < querySize; ++i) {
                            int provInt = queryReader.readInt();
                            byte prov = (byte) provInt;
                            int source = queryReader.readInt();
                            int target = queryReader.readInt();
                            if (model.containsKey(prov) && model.get(prov).containsKey(source)
                                    && model.get(prov).get(source).containsKey(target)) {
                                probWriter.writeDouble(model.get(prov).get(source).get(target));
                            } else {
                                probWriter.writeDouble(Double.MAX_VALUE);
                            }
                        }
                        byteBuffer.writeTo(out);
                        byteBuffer.reset();
                        stopWatch.stop();
                        queryTime += stopWatch.getTime();
                        if (++noOfQueries == 1000) {
                            System.out.println("Time per key = " + (double) queryTime / (double) totalKeys);
                            noOfQueries = 0;
                            queryTime = totalKeys = 0;
                        }
                    } catch (EOFException e) {
                        System.out.println("Connection from mapper closed");
                    }
                }
            }
            querySocket.close();
        }
    }

    static final String TTABLE_S2T_SERVER_PORT = "ttable_s2t_server_port";

    static final String TTABLE_T2S_SERVER_PORT = "ttable_t2s_server_port";

    private static final String LEX_TABLE_TEMPLATE = "ttable_server_template";

    private static final String GENRE = "$GENRE";

    private static final String DIRECTION = "$DIRECTION";

    private ServerSocket serverSocket;

    private Map<Byte, Map<Integer, Map<Integer, Double>>> model = new HashMap<>();

    private double minLexProb = 0;

    private Runnable server = new Runnable() {

        @Override
        public void run() {
            while (true) {
                try {
                    Socket querySocket = serverSocket.accept();
                    threadPool.execute(new QueryRunnable(querySocket));
                } catch (SocketException e) {
                    e.printStackTrace();
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }

        }
    };

    private void startServer() {
        Thread serverThread = new Thread(server);
        serverThread.setDaemon(true);
        serverThread.start();
    }

    private void loadModel(String modelFile, byte prov) throws FileNotFoundException, IOException {
        try (BufferedReader br = new BufferedReader(
                new InputStreamReader(new GZIPInputStream(new FileInputStream(modelFile))))) {
            String line;
            int count = 1;
            while ((line = br.readLine()) != null) {
                if (count % 1000000 == 0) {
                    System.err.println("Processed " + count + " lines");
                }
                count++;
                line = line.replace("NULL", "0");
                String[] parts = StringUtils.split(line, '\\', ' ');
                try {
                    int sourceWord = Integer.parseInt(parts[0]);
                    int targetWord = Integer.parseInt(parts[1]);
                    double model1Probability = Double.parseDouble(parts[2]);
                    if (model1Probability < minLexProb) {
                        continue;
                    }
                    if (!model.get(prov).containsKey(sourceWord)) {
                        model.get(prov).put(sourceWord, new HashMap<Integer, Double>());
                    }
                    model.get(prov).get(sourceWord).put(targetWord, model1Probability);
                } catch (NumberFormatException e) {
                    System.out.println("Unable to parse line: " + e.getMessage() + "\n" + line);
                }
            }
        }
    }

    private void setup(Configuration conf, String direction, boolean source2Target)
            throws IOException, InterruptedException {
        int serverPort;
        if (source2Target) {
            serverPort = Integer.parseInt(conf.get(TTABLE_S2T_SERVER_PORT));
        } else {
            serverPort = Integer.parseInt(conf.get(TTABLE_T2S_SERVER_PORT));
        }
        minLexProb = Double.parseDouble(conf.get("min_lex_prob"));
        serverSocket = new ServerSocket(serverPort);
        String lexTemplate = conf.get(LEX_TABLE_TEMPLATE);
        String allString = lexTemplate.replace(GENRE, "ALL").replace(DIRECTION, direction);
        System.out.println("Loading " + allString);
        String[] provenances = conf.getStrings(ProvenanceCountMap.PROV);
        ExecutorService loaderThreadPool = Executors.newFixedThreadPool(4);
        model.put((byte) 0, new HashMap<Integer, Map<Integer, Double>>());
        loaderThreadPool.execute(new LoadTask(allString, (byte) 0));
        for (int i = 0; i < provenances.length; ++i) {
            String provString = lexTemplate.replace(GENRE, provenances[i]).replace(DIRECTION, direction);
            System.out.println("Loading " + provString);
            byte prov = (byte) (i + 1);
            model.put(prov, new HashMap<Integer, Map<Integer, Double>>());
            loaderThreadPool.execute(new LoadTask(provString, prov));
        }
        loaderThreadPool.shutdown();
        loaderThreadPool.awaitTermination(3, TimeUnit.HOURS);
        System.gc();
    }

    @Override
    public void close() throws IOException {
        threadPool.shutdown();
    }

    /**
     * Defines command line args.
     */
    @Parameters(separators = "=")
    public static class TTableServerParameters {
        @Parameter(names = { "--ttable_s2t_server_port" }, description = "TTable source-to-target server port")
        public String ttable_s2t_server_port = "4949";

        @Parameter(names = { "--ttable_s2t_host" }, description = "TTable source-to-target host name")
        public String ttable_s2t_host = "localhost";

        @Parameter(names = { "--ttable_t2s_server_port" }, description = "TTable target-to-source server port")
        public String ttable_t2s_server_port = "9494";

        @Parameter(names = { "--ttable_t2s_host" }, description = "TTable target-to-source host name")
        public String ttable_t2s_host = "localhost";

        @Parameter(names = {
                "--ttable_server_template" }, description = "TTable target-to-source host name", required = true)
        public String ttable_server_template;

        @Parameter(names = {
                "--ttable_direction" }, description = "TTable direction for the lexical model ('s2t' or 't2s')", required = true)
        public String ttable_direction;

        @Parameter(names = {
                "--ttable_language_pair" }, description = "TTable language pair for the lexical model (e.g. 'en2ru' or 'ru2en')", required = true)
        public String ttable_language_pair;

        @Parameter(names = { "--provenance" }, description = "Comma-separated list of provenances", required = true)
        public String provenance;

        @Parameter(names = {
                "--min_lex_prob" }, description = "Minimum probability for a Model 1 entry. Entries with lower probability are discarded.")
        public String min_lex_prob = "0";
    }

    public int run(String[] args)
            throws IllegalArgumentException, IllegalAccessException, IOException, InterruptedException {
        TTableServerParameters params = new TTableServerParameters();
        JCommander cmd = new JCommander(params);

        try {
            cmd.parse(args);
            Configuration conf = getConf();
            Util.ApplyConf(cmd, "", conf);
            boolean source2Target;
            if (params.ttable_direction.equals("s2t")) {
                source2Target = true;
            } else if (params.ttable_direction.equals("t2s")) {
                source2Target = false;
            } else {
                throw new RuntimeException("Unknown direction: " + args[2]);
            }
            try (TTableServer server = new TTableServer()) {
                server.setup(conf, params.ttable_language_pair, source2Target);
                server.startServer();
                System.err.println("TTable server ready on port: " + server.serverSocket.getLocalPort());
                Thread.sleep(24 * 60 * 60 * 1000); // Sleep for 24 hours
            }
        } catch (ParameterException e) {
            System.err.println(e.getMessage());
            cmd.usage();
        }

        return 1;
    }

    public static void main(String[] args) throws Exception {
        int res = ToolRunner.run(new TTableServer(), args);
        System.exit(res);
    }
}